diff --git a/python/dune/perftool/sumfact/accumulation.py b/python/dune/perftool/sumfact/accumulation.py index e258f7adfa9acb4a9d6e841a2ed863b20a76cb5f..e9abf7dc036262c4047eb1d03506536921d36364 100644 --- a/python/dune/perftool/sumfact/accumulation.py +++ b/python/dune/perftool/sumfact/accumulation.py @@ -172,28 +172,18 @@ class AccumulationOutput(SumfactKernelOutputBase, ImmutableRecord): return frozenset({dep}) - def realize_direct(self, result, inames, shape, args): - ft = get_global_context_value("form_type") - - if self.test_element_index is None: - direct_output = "{}_access".format(self.accumvar) - else: - direct_output = "{}_access_comp{}".format(self.accumvar, self.test_element_index) + def realize_direct(self, result, inames, shape, which=0, **args): + direct_output = "fastdg{}".format(which) ftags = ",".join(["f"] * len(shape)) - from dune.perftool.sumfact.realization import alias_data_array - if ft == 'residual' or ft == 'jacobian_apply': + if self.trial_element is None: globalarg(direct_output, shape=shape, dim_tags=ftags, offset=_dof_offset(self.test_element, self.test_element_index), ) - alias_data_array(direct_output, self.accumvar) lhs = prim.Subscript(prim.Variable(direct_output), inames) else: - assert ft == 'jacobian' - - direct_output = "{}x{}".format(direct_output, self.trial_element_index) rowsize = sum(tuple(s for s in _local_sizes(self.trial_element))) element = get_leaf(self.trial_element, self.trial_element_index) other_shape = tuple(element.degree() + 1 for e in shape) @@ -205,7 +195,6 @@ class AccumulationOutput(SumfactKernelOutputBase, ImmutableRecord): offset=rowsize * _dof_offset(self.test_element, self.test_element_index) + _dof_offset(self.trial_element, self.trial_element_index), dim_tags=dim_tags, ) - alias_data_array(direct_output, self.accumvar) # TODO: It is at least questionnable, whether using the *order* of the inames in here # for indexing is a good idea. Then again, it is hard to find an alternative. 
_ansatz_inames = tuple(prim.Variable(i) for i in self.within_inames) @@ -217,6 +206,10 @@ class AccumulationOutput(SumfactKernelOutputBase, ImmutableRecord): tags=frozenset({"sumfact_stage3"}), **args)}) + @property + def fastdg_args(self): + return ("{}.data()".format(self.accumvar),) + def _local_sizes(element): from ufl import FiniteElement, MixedElement diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py index ee2f068f9eb04f09af174944ee81574acf8d7014..4e53ec4b493948ab5c44a5ac3f9a928f46b3374d 100644 --- a/python/dune/perftool/sumfact/basis.py +++ b/python/dune/perftool/sumfact/basis.py @@ -100,10 +100,8 @@ class LFSSumfactKernelInput(SumfactKernelInputBase, ImmutableRecord): return insn_dep.union(frozenset({insn})) - def realize_direct(self, shape, inames): - arg = "{}_access{}".format(self.coeff_func(self.restriction), - "_comp{}".format(self.element_index) if self.element_index else "" - ) + def realize_direct(self, shape, inames, which=0): + arg = "fastdg{}".format(which) from dune.perftool.sumfact.accumulation import _dof_offset from dune.perftool.sumfact.realization import alias_data_array @@ -113,11 +111,13 @@ class LFSSumfactKernelInput(SumfactKernelInputBase, ImmutableRecord): offset=_dof_offset(self.element, self.element_index), ) - func = self.coeff_func(self.restriction) - alias_data_array(arg, func) - return prim.Subscript(prim.Variable(arg), inames) + @property + def fastdg_args(self): + func = self.coeff_func(self.restriction) + return ("{}.data()".format(func),) + def _basis_functions_per_direction(element): """Number of basis functions per direction """ diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py index dfe080284986f20c1bc7fd641b8cf6b4345d2997..de515848d8a50374af720862084eace1459d04a4 100644 --- a/python/dune/perftool/sumfact/realization.py +++ b/python/dune/perftool/sumfact/realization.py @@ -27,6 +27,7 @@ from dune.perftool.sumfact.permutation 
import (sumfact_permutation_strategy, permute_backward, permute_forward, ) +from dune.perftool.sumfact.symbolic import get_input_output_tuple from dune.perftool.sumfact.vectorization import attach_vectorization_info from dune.perftool.sumfact.accumulation import sumfact_iname from dune.perftool.loopy.target import dtype_floatingpoint @@ -125,8 +126,15 @@ def _realize_sum_factorization_kernel(sf): if not sf.input.direct_input_is_possible: insn_dep = insn_dep.union(sf.input.realize(sf, insn_dep)) + # Collect function call arguments + fastdg_args = () + if sf.stage == 1: + fastdg_args = sf.input.fastdg_args + if sf.stage == 3: + fastdg_args = sf.output.fastdg_args + # Call the function - code = "{}({}, {});".format(funcname, *buffers) + code = "{}({});".format(funcname, ", ".join(buffers + fastdg_args)) tag = "sumfact_stage{}".format(sf.stage) insn_dep = frozenset({instruction(code=code, depends_on=insn_dep, @@ -302,8 +310,8 @@ def realize_sumfact_kernel_function(sf): # of the Sumfactorization into some global data structure. 
if l == len(matrix_sequence) - 1 and get_form_option('fastdg') and sf.stage == 3: if sf.vectorized: - insn_args["forced_iname_deps"] = insn_args["forced_iname_deps"].union(frozenset({vec_iname[0].name})) - insn_dep = sf.output.realize_direct(matprod, output_inames, out_shape, insn_args) + insn_args["forced_iname_deps"] = frozenset({vec_iname[0].name}) + insn_dep = sf.output.realize_direct(matprod, output_inames, out_shape, **insn_args) else: # Issue the reduction instruction that implements the multiplication # at the same time store the instruction ID for the next instruction to depend on @@ -316,7 +324,12 @@ def realize_sumfact_kernel_function(sf): # Construct a loopy kernel object name = name_kernel_implementation_function(sf) from dune.perftool.pdelab.localoperator import extract_kernel_from_cache - signature = "void {}(const char* buffer0, const char* buffer1) const".format(name) + args = ["const char* buffer0", "const char* buffer1"] + if get_form_option('fastdg'): + const = "const " if sf.stage == 1 else "" + args = args + ["{}double* fastdg{}".format(const, i) for i in range(len(get_input_output_tuple(sf)))] + + signature = "void {}({}) const".format(name, ", ".join(args)) kernel = extract_kernel_from_cache("kernel_default", name, [signature], add_timings=False) delete_cache_items("kernel_default") return kernel diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py index 83638f179fd65c52bc0f53670016c3c3f957743f..5dc395f8a8c7f90febab6a00d4f62eb7a504addd 100644 --- a/python/dune/perftool/sumfact/symbolic.py +++ b/python/dune/perftool/sumfact/symbolic.py @@ -73,7 +73,7 @@ class VectorSumfactKernelInput(SumfactKernelInputBase): # from two shorter SIMD types return prim.Call(VCLLowerUpperLoad(dtype_floatingpoint()), (self.inputs[0].realize_direct(shape, inames), - self.inputs[len(self.inputs) // 2].realize_direct(shape, inames), + self.inputs[len(self.inputs) // 2].realize_direct(shape, inames, which=1), ) ) else: @@ 
def _unique_in_order(items):
    """Return the elements of *items* as a tuple, duplicates dropped, first-seen order kept."""
    seen = set()
    unique = []
    for item in items:
        if item not in seen:
            seen.add(item)
            unique.append(item)
    return tuple(unique)


def get_input_output_tuple(sf):
    """Return the distinct inputs (stage 1) or outputs (stage 3) of a kernel as a tuple.

    For a ``SumfactKernel`` the result is the one-element tuple holding
    ``sf.input`` (stage 1) or ``sf.output`` (stage 3).  For a
    ``VectorizedSumfactKernel`` the input/output is iterated and duplicates
    are removed while the order of first occurrence is preserved.

    :arg sf: The kernel object; must provide ``stage`` and, depending on the
        stage, ``input`` or ``output``.
    :raises AssertionError: If the stage is neither 1 nor 3, or if *sf* is
        neither a ``SumfactKernel`` nor a ``VectorizedSumfactKernel``.
    """
    # NOTE(review): the vectorized branches assume that sf.input / sf.output
    # are iterables of hashable objects -- confirm against the vectorized
    # input/output classes.
    if sf.stage == 1:
        if isinstance(sf, SumfactKernel):
            return (sf.input,)
        if isinstance(sf, VectorizedSumfactKernel):
            # This is a free function, so the kernel is 'sf' -- there is no 'self' here.
            return _unique_in_order(sf.input)
    if sf.stage == 3:
        if isinstance(sf, SumfactKernel):
            return (sf.output,)
        if isinstance(sf, VectorizedSumfactKernel):
            return _unique_in_order(sf.output)

    # Raise explicitly instead of 'assert False': an assert statement is
    # removed under 'python -O', which would turn this into a silent
    # 'return None'.
    raise AssertionError(
        "get_input_output_tuple: unsupported kernel type or stage: {!r}".format(sf)
    )