diff --git a/python/dune/codegen/sumfact/accumulation.py b/python/dune/codegen/sumfact/accumulation.py index b43d7906a2802215b394f728a2f466f55018748a..eb44e3cb0967e932d07d3fb98b5480882d58b262 100644 --- a/python/dune/codegen/sumfact/accumulation.py +++ b/python/dune/codegen/sumfact/accumulation.py @@ -154,6 +154,25 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord): from dune.codegen.sumfact.basis import lfs_inames return lfs_inames(get_leaf(self.trial_element, self.trial_element_index), self.restriction) + def realize_input(self, inames, shape, vec_iname, vec_shape, buffer, ftags, l): + # TODO: This should happen in stage 2 and not in stage 3 + shape = permute_backward(shape, self.cost_permutation) + inames = permute_backward(inames, self.cost_permutation) + + # Get a temporary that interprets the base storage of the input + # as a column-major matrix. In later iteration of the matrix loop + # this reinterprets the output of the previous iteration. + inp = buffer.get_temporary("buff_step{}_in".format(l), + shape=shape + vec_shape, + dim_tags=ftags, + ) + + # The input temporary will only be read from, so we need to silence + # the loopy warning + silenced_warning('read_no_write({})'.format(inp)) + + return prim.Subscript(prim.Variable(inp), inames + vec_iname) + def setup_output(self, sf, result, insn_dep, inames=None, additional_inames=()): trial_leaf_element = get_leaf(self.trial_element, self.trial_element_index) if self.trial_element is not None else None @@ -213,6 +232,9 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord): return frozenset({dep}) def realize_direct_output(self, result, inames, shape, which=0, **args): + inames = permute_backward(inames, self.cost_permutation) + inames = permute_backward(inames, self.quadrature_permutation) + direct_output = "fastdg{}".format(which) ftags = ",".join(["f"] * len(shape)) @@ -241,25 +263,6 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord): tags=frozenset({"sumfact_stage3"}), **args)}) - def realize_input(self, shape, inames, vec_shape, vec_iname, buffer, ftags, l): - # TODO: This should happen in stage 2 and not in stage 3 - shape = permute_backward(shape, self.cost_permutation) - inames = permute_backward(inames, self.cost_permutation) - - # Get a temporary that interprets the base storage of the input - # as a column-major matrix. In later iteration of the matrix loop - # this reinterprets the output of the previous iteration. - inp = buffer.get_temporary("buff_step{}_in".format(l), - shape=shape + vec_shape, - dim_tags=ftags, - ) - - # The input temporary will only be read from, so we need to silence - # the loopy warning - silenced_warning('read_no_write({})'.format(inp)) - - return prim.Subscript(prim.Variable(inp), inames + vec_iname) - @property def function_name_suffix(self): if get_form_option("fastdg"): diff --git a/python/dune/codegen/sumfact/basis.py b/python/dune/codegen/sumfact/basis.py index 78bf364b633258c4b5eb2f09a8634398c6cac6ad..a856c5a95c496c4b18ffe0b8d848003799d236d0 100644 --- a/python/dune/codegen/sumfact/basis.py +++ b/python/dune/codegen/sumfact/basis.py @@ -161,6 +161,24 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord): return insn_dep.union(frozenset({insn})) + def realize_input(self, inames, shape, vec_iname, vec_shape, buffer, ftags, l): + # Note: Here we do not need to reverse any permutation since this is + # already done in the setup_input method above! + + # Get a temporary that interprets the base storage of the input + # as a column-major matrix. In later iteration of the matrix loop + # this reinterprets the output of the previous iteration. + inp = buffer.get_temporary("buff_step{}_in".format(l), + shape=shape + vec_shape, + dim_tags=ftags, + ) + + # The input temporary will only be read from, so we need to silence + # the loopy warning + silenced_warning('read_no_write({})'.format(inp)) + + return prim.Subscript(prim.Variable(inp), inames + vec_iname) + def realize_direct_input(self, shape, inames, which=0): # If the input comes directly from a global data structure inames are # ordered x,y,z,... @@ -183,24 +201,6 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord): return prim.Subscript(prim.Variable(arg), inames) - def realize_input(self, shape, inames, vec_shape, vec_iname, buffer, ftags, l): - # Note: Here we do not need to reverse any permutation since this is - # already done in the setup_input method above! - - # Get a temporary that interprets the base storage of the input - # as a column-major matrix. In later iteration of the matrix loop - # this reinterprets the output of the previous iteration. - inp = buffer.get_temporary("buff_step{}_in".format(l), - shape=shape + vec_shape, - dim_tags=ftags, - ) - - # The input temporary will only be read from, so we need to silence - # the loopy warning - silenced_warning('read_no_write({})'.format(inp)) - - return prim.Subscript(prim.Variable(inp), inames + vec_iname) - @property def function_name_suffix(self): if get_form_option("fastdg"): diff --git a/python/dune/codegen/sumfact/geometry.py b/python/dune/codegen/sumfact/geometry.py index 8652822725a24c53816ca7fe5cc20752ceb37bf3..8b714a419e69d65ad9d59dda732e93577bc1ae5a 100644 --- a/python/dune/codegen/sumfact/geometry.py +++ b/python/dune/codegen/sumfact/geometry.py @@ -172,7 +172,7 @@ class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableRecord): return insn_dep.union(frozenset({insn})) - def realize_input(self, shape, inames, vec_shape, vec_iname, buffer, ftags, l): + def realize_input(self, inames, shape, vec_iname, vec_shape, buffer, ftags, l): # Get a temporary that interprets the base storage of the input # as a column-major matrix. In later iteration of the matrix loop # this reinterprets the output of the previous iteration. diff --git a/python/dune/codegen/sumfact/realization.py b/python/dune/codegen/sumfact/realization.py index 73adc444852173fe49afb3639232baf05af4bd75..d1f10d6f9c659fd8c719e19a59658245e5c27b39 100644 --- a/python/dune/codegen/sumfact/realization.py +++ b/python/dune/codegen/sumfact/realization.py @@ -195,10 +195,10 @@ def realize_sumfact_kernel_function(sf): input_summand = sf.interface.realize_direct_input(inp_shape, input_inames) elif l == 0: # TODO: Simplify arguments! - input_summand = sf.interface.realize_input(inp_shape, - input_inames, - vec_shape, + input_summand = sf.interface.realize_input(input_inames, + inp_shape, vec_iname, + vec_shape, buffer, ftags, l, @@ -238,36 +238,21 @@ def realize_sumfact_kernel_function(sf): # In case of direct output we directly accumulate the result # of the Sumfactorization into some global data structure. if l == len(matrix_sequence) - 1 and get_form_option('fastdg') and sf.stage == 3: - # TODO: Move permutations to interface! - output_inames = permute_backward(output_inames, sf.cost_permutation) - output_inames = permute_backward(output_inames, sf.interface.quadrature_permutation) - if sf.vectorized: insn_args["forced_iname_deps"] = frozenset({vec_iname[0].name}) insn_dep = sf.interface.realize_direct_output(matprod, output_inames, out_shape, **insn_args) elif l == len(matrix_sequence) - 1: - # TODO: Move permutations to interface! - output_inames = permute_backward(output_inames, sf.cost_permutation) output_shape = tuple(out_shape[1:]) + (out_shape[0],) - - # TODO: Move permutations to interface - output_shape = permute_backward(output_shape, sf.cost_permutation) - if sf.stage == 3: - output_inames = permute_backward(output_inames, sf.interface.quadrature_permutation) - output_shape = permute_backward(output_shape, sf.interface.quadrature_permutation) - - out = buffer.get_temporary("buff_step{}_out".format(l), - shape=output_shape + vec_shape, - dim_tags=ftags, - ) - - # Issue the reduction instruction that implements the multiplication - # at the same time store the instruction ID for the next instruction to depend on - insn_dep = frozenset({instruction(assignee=prim.Subscript(prim.Variable(out), output_inames + vec_iname), - expression=matprod, - **insn_args - ) - }) + insn_dep = sf.interface.realize_output(matprod, + output_inames, + output_shape, + vec_iname, + vec_shape, + buffer, + ftags, + l, + **insn_args, + ) else: output_shape = tuple(out_shape[1:]) + (out_shape[0],) out = buffer.get_temporary("buff_step{}_out".format(l), diff --git a/python/dune/codegen/sumfact/symbolic.py b/python/dune/codegen/sumfact/symbolic.py index 440de43a9704c654de165b786fac6a288b8fa26d..8c7052add83dcfb5ba4ff4840b8ad57e11c94a72 100644 --- a/python/dune/codegen/sumfact/symbolic.py +++ b/python/dune/codegen/sumfact/symbolic.py @@ -2,6 +2,7 @@ from dune.codegen.options import get_form_option, get_option from dune.codegen.generation import (get_counted_variable, + instruction, silenced_warning, subst_rule, transform, @@ -43,12 +44,12 @@ class SumfactKernelInterfaceBase(object): """ raise NotImplementedError - def realize_input(self, shape, inames, vec_shape, vec_iname, buffer, ftags, l): + def realize_input(self, inames, shape, vec_iname, vec_shape, buffer, ftags, l): """Interpret the input of sumfact kernel function in the right way (non fastdgg) This happens inside the sumfact kernel function. - TODO: Cleanup input + TODO: Cleanup arguments TODO: Add note about permutation TODO: Document input arguments """ @@ -64,24 +65,52 @@ class SumfactKernelInterfaceBase(object): """ raise NotImplementedError - def realize_direct_output(self, result, iname, shape, which=0, **args): - """Accumulate results directly in the sumfact kernel function (fastdg) + def setup_output(self, sf, result, insn_dep, inames=None, additional_inames=()): + """Generate accumulate instruction after sumfact kernel function (non fastdg) + + This happens after the function call. + + TODO: Add note about permutation + TODO: Document input arguments + """ + raise NotImplementedError + + def realize_output(self, result, inames, shape, vec_iname, vec_shape, buffer, ftags, l, **args): + """Handle the output of the last tensor contraction in the sumfact kernel function the right way This happens inside the sumfact kernel function. + TODO: Cleanup arguments TODO: Add note about permutation TODO: Document input arguments """ + inames = permute_backward(inames, self.cost_permutation) + shape = permute_backward(shape, self.cost_permutation) + if self.stage == 3: + inames = permute_backward(inames, self.quadrature_permutation) + shape = permute_backward(shape, self.quadrature_permutation) - def setup_output(self, sf, result, insn_dep, inames=None, additional_inames=()): - """Generate accumulate instruction after sumfact kernel function (non fastdg) + out = buffer.get_temporary("buff_step{}_out".format(l), + shape=shape + vec_shape, + dim_tags=ftags, + ) - This happens after the function call. + # Issue the reduction instruction that implements the multiplication + # at the same time store the instruction ID for the next instruction to depend on + return frozenset({instruction(assignee=prim.Subscript(prim.Variable(out), inames + vec_iname), + expression=result, + **args + ) + }) + + def realize_direct_output(self, result, iname, shape, which=0, **args): + """Accumulate results directly in the sumfact kernel function (fastdg) + + This happens inside the sumfact kernel function. TODO: Add note about permutation TODO: Document input arguments """ - raise NotImplementedError @property def quadrature_permutation(self): @@ -178,7 +207,7 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase): for i in self.interfaces: assert i.cost_permutation == cost_permutation - return vector_cost_permutation + return self.vector_cost_permutation @property def stage(self): @@ -193,6 +222,23 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase): dep = dep.union(inp.setup_input(sf, dep, index=i)) return dep + def realize_input(self, inames, shape, vec_iname, vec_shape, buffer, ftags, l): + # TODO: vector_cost_permutation not used! + + # Get a temporary that interprets the base storage of the input + # as a column-major matrix. In later iteration of the matrix loop + # this reinterprets the output of the previous iteration. + inp = buffer.get_temporary("buff_step{}_in".format(l), + shape=shape + vec_shape, + dim_tags=ftags, + ) + + # The input temporary will only be read from, so we need to silence + # the loopy warning + silenced_warning('read_no_write({})'.format(inp)) + + return prim.Subscript(prim.Variable(inp), inames + vec_iname) + def realize_direct_input(self, shape, inames): # TODO: vector_cost_permutation not used! @@ -222,23 +268,6 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase): # need to load scalars into the SIMD vector. raise NotImplementedError("SIMD loads from scalars not implemented!") - def realize_input(self, shape, inames, vec_shape, vec_iname, buffer, ftags, l): - # TODO: vector_cost_permutation not used! - - # Get a temporary that interprets the base storage of the input - # as a column-major matrix. In later iteration of the matrix loop - # this reinterprets the output of the previous iteration. - inp = buffer.get_temporary("buff_step{}_in".format(l), - shape=shape + vec_shape, - dim_tags=ftags, - ) - - # The input temporary will only be read from, so we need to silence - # the loopy warning - silenced_warning('read_no_write({})'.format(inp)) - - return prim.Subscript(prim.Variable(inp), inames + vec_iname) - @property def function_args(self): return sum((i.function_args for i in remove_duplicates(self.interfaces)), ()) @@ -262,11 +291,15 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase): class VectorSumfactKernelOutput(SumfactKernelInterfaceBase): def __init__(self, interfaces, perm): self.interfaces = interfaces - self.vector_cost_permutation = perm + self._cost_permutation = perm def __repr__(self): return "_".join(repr(o) for o in self.interfaces) + @property + def cost_permutation(self): + return self._cost_permutation + @property def quadrature_permutation(self): # TODO: For now we assure that all kerneles have the same quadrature_permutation @@ -293,29 +326,29 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase): return prim.Call(prim.Variable(hadd_function), (result,)) - def setup_output(self, sf, result, insn_dep): - # TODO: vector_cost_permutation not used! - - outputs = set(self.interfaces) + def realize_input(self, inames, shape, vec_iname, vec_shape, buffer, ftags, l): + # TODO: Include permutations of scalar kernels as soon as they could be different + shape = permute_backward(shape, self.cost_permutation) + inames = permute_backward(inames, self.cost_permutation) - trial_element, = set(o.trial_element for o in self.interfaces) - trial_element_index = set(o.trial_element_index for o in self.interfaces).pop() - from dune.codegen.sumfact.accumulation import accum_iname - element = get_leaf(trial_element, trial_element_index) if trial_element is not None else None - inames = tuple(accum_iname(element, mat.rows, i) - for i, mat in enumerate(sf.matrix_sequence_quadrature_permuted)) - veciname = accum_iname(element, sf.vector_width // len(outputs), "vec") - transform(lp.tag_inames, [(veciname, "vec")]) + # Get a temporary that interprets the base storage of the input + # as a column-major matrix. In later iteration of the matrix loop + # this reinterprets the output of the previous iteration. + inp = buffer.get_temporary("buff_step{}_in".format(l), + shape=shape + vec_shape, + dim_tags=ftags, + ) - deps = frozenset() - for o in outputs: - hadd_result = self._add_hadd(o, maybe_wrap_subscript(result, tuple(prim.Variable(iname) for iname in inames + (veciname,)))) - deps = deps.union(o.setup_output(sf, hadd_result, insn_dep, inames=inames, additional_inames=(veciname,))) + # The input temporary will only be read from, so we need to silence + # the loopy warning + silenced_warning('read_no_write({})'.format(inp)) - return deps + return prim.Subscript(prim.Variable(inp), inames + vec_iname) def realize_direct_output(self, result, inames, shape, **args): - # TODO: vector_cost_permutation not used! + # TODO: Find out what needs to happen here + # inames = permute_backward(inames, self.cost_permutation) + # shape = permute_backward(shape, self.cost_permutation) outputs = set(self.interfaces) @@ -335,24 +368,26 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase): return deps - def realize_input(self, shape, inames, vec_shape, vec_iname, buffer, ftags, l): - # TODO: Include permutations of scalar kernels as soon as they could be different - shape = permute_backward(shape, self.vector_cost_permutation) - inames = permute_backward(inames, self.vector_cost_permutation) + def setup_output(self, sf, result, insn_dep): + # TODO: vector_cost_permutation not used! - # Get a temporary that interprets the base storage of the input - # as a column-major matrix. In later iteration of the matrix loop - # this reinterprets the output of the previous iteration. - inp = buffer.get_temporary("buff_step{}_in".format(l), - shape=shape + vec_shape, - dim_tags=ftags, - ) + outputs = set(self.interfaces) - # The input temporary will only be read from, so we need to silence - # the loopy warning - silenced_warning('read_no_write({})'.format(inp)) + trial_element, = set(o.trial_element for o in self.interfaces) + trial_element_index = set(o.trial_element_index for o in self.interfaces).pop() + from dune.codegen.sumfact.accumulation import accum_iname + element = get_leaf(trial_element, trial_element_index) if trial_element is not None else None + inames = tuple(accum_iname(element, mat.rows, i) + for i, mat in enumerate(sf.matrix_sequence_quadrature_permuted)) + veciname = accum_iname(element, sf.vector_width // len(outputs), "vec") + transform(lp.tag_inames, [(veciname, "vec")]) - return prim.Subscript(prim.Variable(inp), inames + vec_iname) + deps = frozenset() + for o in outputs: + hadd_result = self._add_hadd(o, maybe_wrap_subscript(result, tuple(prim.Variable(iname) for iname in inames + (veciname,)))) + deps = deps.union(o.setup_output(sf, hadd_result, insn_dep, inames=inames, additional_inames=(veciname,))) + + return deps @property def function_args(self):