diff --git a/python/dune/perftool/sumfact/accumulation.py b/python/dune/perftool/sumfact/accumulation.py index 6957b408366861b0223287f0c7a59cd08d996cc2..dd7605084f960f5b2621eeafe34cf7b0a6da671c 100644 --- a/python/dune/perftool/sumfact/accumulation.py +++ b/python/dune/perftool/sumfact/accumulation.py @@ -186,10 +186,11 @@ class AccumulationOutput(SumfactKernelOutputBase, ImmutableRecord): else: rowsize = sum(tuple(s for s in _local_sizes(self.trial_element))) manual_strides = tuple("stride:{}".format(rowsize * product(shape[:i])) for i in range(len(shape))) - valuearg("jacobian_offset") + offset = "jacobian_offset{}".format(which) + valuearg(offset) globalarg(direct_output, shape=shape, - offset=prim.Variable("jacobian_offset") + rowsize * _dof_offset(self.test_element, self.test_element_index) + _dof_offset(self.trial_element, self.trial_element_index), + offset=prim.Variable(offset) + rowsize * _dof_offset(self.test_element, self.test_element_index) + _dof_offset(self.trial_element, self.trial_element_index), dim_tags=manual_strides, ) lhs = prim.Subscript(prim.Variable(direct_output), inames) @@ -204,10 +205,10 @@ class AccumulationOutput(SumfactKernelOutputBase, ImmutableRecord): def fastdg_args(self): if get_form_option("fastdg"): ret = ("{}.data()".format(self.accumvar),) - if get_form_option("fastdg") and sf.within_inames: - element = get_leaf(sf.output.trial_element, sf.output.trial_element_index) + if get_form_option("fastdg") and self.within_inames: + element = get_leaf(self.trial_element, self.trial_element_index) shape = tuple(element.degree() + 1 for e in range(element.cell().geometric_dimension())) - jacobian_index = flatten_index(tuple(prim.Variable(i) for i in sf.within_inames), shape, order="f") + jacobian_index = flatten_index(tuple(prim.Variable(i) for i in self.within_inames), shape, order="f") ret = ret + (str(jacobian_index),) return ret else: diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py index 1a7188e9d5fba39af52d05afc0ef4180d71d2335..ee7ff4d74f94617c4df158895377132d4ff632ae 100644 --- a/python/dune/perftool/sumfact/realization.py +++ b/python/dune/perftool/sumfact/realization.py @@ -37,7 +37,7 @@ from dune.perftool.sumfact.vectorization import attach_vectorization_info from dune.perftool.sumfact.accumulation import sumfact_iname from dune.perftool.loopy.target import dtype_floatingpoint from dune.perftool.loopy.vcl import ExplicitVCLCast -from dune.perftool.tools import get_leaf +from dune.perftool.tools import get_leaf, remove_duplicates from pytools import product from ufl import MixedElement @@ -58,14 +58,19 @@ def _name_kernel_implementation_function(sf, qp): if isinstance(sf, SumfactKernel): fastdg = "{}comp{}".format(FEM_name_mangling(sf.input.element), sf.input.element_index) if isinstance(sf, VectorizedSumfactKernel): - 1/0 + fastdg = "_".join("{}comp{}".format(FEM_name_mangling(i.element), i.element_index) for i in remove_duplicates(sf.input.inputs)) if sf.stage == 3: if isinstance(sf, SumfactKernel): fastdg = "{}comp{}".format(FEM_name_mangling(sf.output.test_element), sf.output.test_element_index) - if sf.output.trial_element: + if sf.within_inames: fastdg = "{}x{}comp{}".format(fastdg, FEM_name_mangling(sf.output.trial_element), sf.output.trial_element_index) if isinstance(sf, VectorizedSumfactKernel): - 1/0 + fastdg = "_".join("{}comp{}".format(FEM_name_mangling(i.test_element), i.test_element_index) for i in remove_duplicates(sf.output.outputs)) + if sf.within_inames: + fastdg = "{}x{}".format(fastdg, + "_".join("{}comp{}".format(FEM_name_mangling(i.trial_element), i.trial_element_index) for i in remove_duplicates(sf.output.outputs)) + ) + name = "{}_fastdg{}_{}".format(name, sf.stage, fastdg) necessary_kernel_implementations((sf, qp)) return name @@ -347,9 +352,10 @@ def realize_sumfact_kernel_function(sf): args = ["const char* buffer0", "const char* buffer1"] if get_form_option('fastdg'): const = "const " if sf.stage == 1 else "" - args = args + ["{}double* fastdg{}".format(const, i) for i in range(len(get_input_output_tuple(sf)))] - if sf.within_inames: - args = args + ["unsigned int jacobian_offset"] + for i in range(len(get_input_output_tuple(sf))): + args.append("{}double* fastdg{}".format(const, i)) + if sf.within_inames: + args.append("unsigned int jacobian_offset{}".format(i)) signature = "void {}({}) const".format(name, ", ".join(args)) kernel = extract_kernel_from_cache("kernel_default", name, [signature], add_timings=False) diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py index 45d6d0cf516d3a9be00c9f51bf4a07f317a32681..c3be20ddbb2266c439c0e6513665176f59b15a18 100644 --- a/python/dune/perftool/sumfact/symbolic.py +++ b/python/dune/perftool/sumfact/symbolic.py @@ -10,7 +10,7 @@ from dune.perftool.sumfact.quadrature import quadrature_inames from dune.perftool.sumfact.tabulation import BasisTabulationMatrixBase, BasisTabulationMatrixArray from dune.perftool.loopy.target import dtype_floatingpoint from dune.perftool.loopy.vcl import ExplicitVCLCast, VCLLowerUpperLoad -from dune.perftool.tools import get_leaf, maybe_wrap_subscript +from dune.perftool.tools import get_leaf, maybe_wrap_subscript, remove_duplicates from pytools import ImmutableRecord, product @@ -81,6 +81,10 @@ class VectorSumfactKernelInput(SumfactKernelInputBase): # need to load scalars into the SIMD vector. raise NotImplementedError("SIMD loads from scalars not implemented!") + @property + def fastdg_args(self): + return sum((i.fastdg_args for i in remove_duplicates(self.inputs)), ()) + class SumfactKernelOutputBase(object): @property @@ -143,15 +147,20 @@ class VectorSumfactKernelOutput(SumfactKernelOutputBase): substname = "haddsubst_{}".format("_".join([i.name for i in inames])) subst_rule(substname, (), result) result = prim.Call(prim.Variable(substname), ()) - transform(lp.precompute, substname, precompute_outer_inames=args["forced_iname_deps"]) + transform(lp.precompute, substname) deps = frozenset() for o in outputs: hadd_result = self._add_hadd(o, result) - deps = deps.union(o.realize_direct(hadd_result, inames, shape, which=self.outputs.index(o), **args)) + which = tuple(remove_duplicates(self.outputs)).index(o) + deps = deps.union(o.realize_direct(hadd_result, inames, shape, which=which, **args)) return deps + @property + def fastdg_args(self): + return sum((i.fastdg_args for i in remove_duplicates(self.outputs)), ()) + class SumfactKernelBase(object): pass @@ -548,7 +557,12 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) # @property def function_key(self): - fastdg = self.inout_key if get_form_option("fastdg") else () + fastdg = () + if get_form_option("fastdg"): + if self.stage == 1: + fastdg = sum(((i.element, i.element_index) for i in remove_duplicates(self.input.inputs)), ()) + if self.stage == 3: + fastdg = sum(((o.test_element, o.test_element_index, o.trial_element, o.trial_element_index) for o in remove_duplicates(self.output.outputs)), ()) return tuple(str(m) for m in self.matrix_sequence) + fastdg @property @@ -775,14 +789,10 @@ def get_input_output_tuple(sf): if isinstance(sf, SumfactKernel): return (sf.input,) if isinstance(sf, VectorizedSumfactKernel): - # This is a short recipe for removing duplicates from an iterable - # while preserving the order! - seen = set() - return tuple(x for x in self.input if not (x in seen or seen.add(x))) + return tuple(remove_duplicates(sf.input.inputs)) if sf.stage == 3: if isinstance(sf, SumfactKernel): return (sf.output,) if isinstance(sf, VectorizedSumfactKernel): - seen = set() - return tuple(x for x in self.output if not (x in seen or seen.add(x))) + return tuple(remove_duplicates(sf.output.outputs)) assert(False) diff --git a/python/dune/perftool/tools.py b/python/dune/perftool/tools.py index b29ebe221387a5af529f0dec29592b4de8dc762a..4a95dbced5e24d7efd0ddadc422511a1fc1362cc 100644 --- a/python/dune/perftool/tools.py +++ b/python/dune/perftool/tools.py @@ -93,3 +93,12 @@ def get_leaf(element, index): leaf_element = element.extract_component(index)[1] return leaf_element + + +def remove_duplicates(iterable): + """ Remove duplicates from an iterable while preserving the order """ + seen = set() + for i in iterable: + if i not in seen: + yield i + seen.add(i)