From 38072649ec66d73506470e2bcef9f50be6acc121 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20He=C3=9F?= <rene.hess@iwr.uni-heidelberg.de> Date: Fri, 23 Nov 2018 09:10:24 +0100 Subject: [PATCH] [skip ci] Improve sumfact kernel interface Introduce different methods for realize_input/output realize_direct_input/output and setup_input/output. The setup methods cover code generation outside the sumfact kernel function (creating input array or accumulating result). realize and realize_direct handle the input/output in the nonfastdg and fastdg code branch. Seperate interface methods make it a lot easier to find out where each of those methods will be applied. Besides that most interface classes need to provide more that two of those methods anyway... --- python/dune/codegen/sumfact/accumulation.py | 41 +++--- python/dune/codegen/sumfact/basis.py | 36 ++--- python/dune/codegen/sumfact/geometry.py | 2 +- python/dune/codegen/sumfact/realization.py | 41 ++---- python/dune/codegen/sumfact/symbolic.py | 155 ++++++++++++-------- 5 files changed, 149 insertions(+), 126 deletions(-) diff --git a/python/dune/codegen/sumfact/accumulation.py b/python/dune/codegen/sumfact/accumulation.py index b43d7906..eb44e3cb 100644 --- a/python/dune/codegen/sumfact/accumulation.py +++ b/python/dune/codegen/sumfact/accumulation.py @@ -154,6 +154,25 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord): from dune.codegen.sumfact.basis import lfs_inames return lfs_inames(get_leaf(self.trial_element, self.trial_element_index), self.restriction) + def realize_input(self, inames, shape, vec_iname, vec_shape, buffer, ftags, l): + # TODO: This should happen in stage 2 and not in stage 3 + shape = permute_backward(shape, self.cost_permutation) + inames = permute_backward(inames, self.cost_permutation) + + # Get a temporary that interprets the base storage of the input + # as a column-major matrix. In later iteration of the matrix loop + # this reinterprets the output of the previous iteration. + inp = buffer.get_temporary("buff_step{}_in".format(l), + shape=shape + vec_shape, + dim_tags=ftags, + ) + + # The input temporary will only be read from, so we need to silence + # the loopy warning + silenced_warning('read_no_write({})'.format(inp)) + + return prim.Subscript(prim.Variable(inp), inames + vec_iname) + def setup_output(self, sf, result, insn_dep, inames=None, additional_inames=()): trial_leaf_element = get_leaf(self.trial_element, self.trial_element_index) if self.trial_element is not None else None @@ -213,6 +232,9 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord): return frozenset({dep}) def realize_direct_output(self, result, inames, shape, which=0, **args): + inames = permute_backward(inames, self.cost_permutation) + inames = permute_backward(inames, self.quadrature_permutation) + direct_output = "fastdg{}".format(which) ftags = ",".join(["f"] * len(shape)) @@ -241,25 +263,6 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord): tags=frozenset({"sumfact_stage3"}), **args)}) - def realize_input(self, shape, inames, vec_shape, vec_iname, buffer, ftags, l): - # TODO: This should happen in stage 2 and not in stage 3 - shape = permute_backward(shape, self.cost_permutation) - inames = permute_backward(inames, self.cost_permutation) - - # Get a temporary that interprets the base storage of the input - # as a column-major matrix. In later iteration of the matrix loop - # this reinterprets the output of the previous iteration. - inp = buffer.get_temporary("buff_step{}_in".format(l), - shape=shape + vec_shape, - dim_tags=ftags, - ) - - # The input temporary will only be read from, so we need to silence - # the loopy warning - silenced_warning('read_no_write({})'.format(inp)) - - return prim.Subscript(prim.Variable(inp), inames + vec_iname) - @property def function_name_suffix(self): if get_form_option("fastdg"): diff --git a/python/dune/codegen/sumfact/basis.py b/python/dune/codegen/sumfact/basis.py index 78bf364b..a856c5a9 100644 --- a/python/dune/codegen/sumfact/basis.py +++ b/python/dune/codegen/sumfact/basis.py @@ -161,6 +161,24 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord): return insn_dep.union(frozenset({insn})) + def realize_input(self, inames, shape, vec_iname, vec_shape, buffer, ftags, l): + # Note: Here we do not need to reverse any permutation since this is + # already done in the setup_input method above! + + # Get a temporary that interprets the base storage of the input + # as a column-major matrix. In later iteration of the matrix loop + # this reinterprets the output of the previous iteration. + inp = buffer.get_temporary("buff_step{}_in".format(l), + shape=shape + vec_shape, + dim_tags=ftags, + ) + + # The input temporary will only be read from, so we need to silence + # the loopy warning + silenced_warning('read_no_write({})'.format(inp)) + + return prim.Subscript(prim.Variable(inp), inames + vec_iname) + def realize_direct_input(self, shape, inames, which=0): # If the input comes directly from a global data structure inames are # ordered x,y,z,... @@ -183,24 +201,6 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord): return prim.Subscript(prim.Variable(arg), inames) - def realize_input(self, shape, inames, vec_shape, vec_iname, buffer, ftags, l): - # Note: Here we do not need to reverse any permutation since this is - # already done in the setup_input method above! - - # Get a temporary that interprets the base storage of the input - # as a column-major matrix. In later iteration of the matrix loop - # this reinterprets the output of the previous iteration. - inp = buffer.get_temporary("buff_step{}_in".format(l), - shape=shape + vec_shape, - dim_tags=ftags, - ) - - # The input temporary will only be read from, so we need to silence - # the loopy warning - silenced_warning('read_no_write({})'.format(inp)) - - return prim.Subscript(prim.Variable(inp), inames + vec_iname) - @property def function_name_suffix(self): if get_form_option("fastdg"): diff --git a/python/dune/codegen/sumfact/geometry.py b/python/dune/codegen/sumfact/geometry.py index 86528227..8b714a41 100644 --- a/python/dune/codegen/sumfact/geometry.py +++ b/python/dune/codegen/sumfact/geometry.py @@ -172,7 +172,7 @@ class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableRecord): return insn_dep.union(frozenset({insn})) - def realize_input(self, shape, inames, vec_shape, vec_iname, buffer, ftags, l): + def realize_input(self, inames, shape, vec_iname, vec_shape, buffer, ftags, l): # Get a temporary that interprets the base storage of the input # as a column-major matrix. In later iteration of the matrix loop # this reinterprets the output of the previous iteration. diff --git a/python/dune/codegen/sumfact/realization.py b/python/dune/codegen/sumfact/realization.py index 73adc444..d1f10d6f 100644 --- a/python/dune/codegen/sumfact/realization.py +++ b/python/dune/codegen/sumfact/realization.py @@ -195,10 +195,10 @@ def realize_sumfact_kernel_function(sf): input_summand = sf.interface.realize_direct_input(inp_shape, input_inames) elif l == 0: # TODO: Simplify arguments! - input_summand = sf.interface.realize_input(inp_shape, - input_inames, - vec_shape, + input_summand = sf.interface.realize_input(input_inames, + inp_shape, vec_iname, + vec_shape, buffer, ftags, l, @@ -238,36 +238,21 @@ def realize_sumfact_kernel_function(sf): # In case of direct output we directly accumulate the result # of the Sumfactorization into some global data structure. if l == len(matrix_sequence) - 1 and get_form_option('fastdg') and sf.stage == 3: - # TODO: Move permutations to interface! - output_inames = permute_backward(output_inames, sf.cost_permutation) - output_inames = permute_backward(output_inames, sf.interface.quadrature_permutation) - if sf.vectorized: insn_args["forced_iname_deps"] = frozenset({vec_iname[0].name}) insn_dep = sf.interface.realize_direct_output(matprod, output_inames, out_shape, **insn_args) elif l == len(matrix_sequence) - 1: - # TODO: Move permutations to interface! - output_inames = permute_backward(output_inames, sf.cost_permutation) output_shape = tuple(out_shape[1:]) + (out_shape[0],) - - # TODO: Move permutations to interface - output_shape = permute_backward(output_shape, sf.cost_permutation) - if sf.stage == 3: - output_inames = permute_backward(output_inames, sf.interface.quadrature_permutation) - output_shape = permute_backward(output_shape, sf.interface.quadrature_permutation) - - out = buffer.get_temporary("buff_step{}_out".format(l), - shape=output_shape + vec_shape, - dim_tags=ftags, - ) - - # Issue the reduction instruction that implements the multiplication - # at the same time store the instruction ID for the next instruction to depend on - insn_dep = frozenset({instruction(assignee=prim.Subscript(prim.Variable(out), output_inames + vec_iname), - expression=matprod, - **insn_args - ) - }) + insn_dep = sf.interface.realize_output(matprod, + output_inames, + output_shape, + vec_iname, + vec_shape, + buffer, + ftags, + l, + **insn_args, + ) else: output_shape = tuple(out_shape[1:]) + (out_shape[0],) out = buffer.get_temporary("buff_step{}_out".format(l), diff --git a/python/dune/codegen/sumfact/symbolic.py b/python/dune/codegen/sumfact/symbolic.py index 440de43a..8c7052ad 100644 --- a/python/dune/codegen/sumfact/symbolic.py +++ b/python/dune/codegen/sumfact/symbolic.py @@ -2,6 +2,7 @@ from dune.codegen.options import get_form_option, get_option from dune.codegen.generation import (get_counted_variable, + instruction, silenced_warning, subst_rule, transform, @@ -43,12 +44,12 @@ class SumfactKernelInterfaceBase(object): """ raise NotImplementedError - def realize_input(self, shape, inames, vec_shape, vec_iname, buffer, ftags, l): + def realize_input(self, inames, shape, vec_iname, vec_shape, buffer, ftags, l): """Interpret the input of sumfact kernel function in the right way (non fastdgg) This happens inside the sumfact kernel function. - TODO: Cleanup input + TODO: Cleanup arguments TODO: Add note about permutation TODO: Document input arguments """ @@ -64,24 +65,52 @@ class SumfactKernelInterfaceBase(object): """ raise NotImplementedError - def realize_direct_output(self, result, iname, shape, which=0, **args): - """Accumulate results directly in the sumfact kernel function (fastdg) + def setup_output(self, sf, result, insn_dep, inames=None, additional_inames=()): + """Generate accumulate instruction after sumfact kernel function (non fastdg) + + This happens after the function call. + + TODO: Add note about permutation + TODO: Document input arguments + """ + raise NotImplementedError + + def realize_output(self, result, inames, shape, vec_iname, vec_shape, buffer, ftags, l, **args): + """Handle the output of the last tensor contraction in the sumfact kernel function the right way This happens inside the sumfact kernel function. + TODO: Cleanup arguments TODO: Add note about permutation TODO: Document input arguments """ + inames = permute_backward(inames, self.cost_permutation) + shape = permute_backward(shape, self.cost_permutation) + if self.stage == 3: + inames = permute_backward(inames, self.quadrature_permutation) + shape = permute_backward(shape, self.quadrature_permutation) - def setup_output(self, sf, result, insn_dep, inames=None, additional_inames=()): - """Generate accumulate instruction after sumfact kernel function (non fastdg) + out = buffer.get_temporary("buff_step{}_out".format(l), + shape=shape + vec_shape, + dim_tags=ftags, + ) - This happens after the function call. + # Issue the reduction instruction that implements the multiplication + # at the same time store the instruction ID for the next instruction to depend on + return frozenset({instruction(assignee=prim.Subscript(prim.Variable(out), inames + vec_iname), + expression=result, + **args + ) + }) + + def realize_direct_output(self, result, iname, shape, which=0, **args): + """Accumulate results directly in the sumfact kernel function (fastdg) + + This happens inside the sumfact kernel function. TODO: Add note about permutation TODO: Document input arguments """ - raise NotImplementedError @property def quadrature_permutation(self): @@ -178,7 +207,7 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase): for i in self.interfaces: assert i.cost_permutation == cost_permutation - return vector_cost_permutation + return self.vector_cost_permutation @property def stage(self): @@ -193,6 +222,23 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase): dep = dep.union(inp.setup_input(sf, dep, index=i)) return dep + def realize_input(self, inames, shape, vec_iname, vec_shape, buffer, ftags, l): + # TODO: vector_cost_permutation not used! + + # Get a temporary that interprets the base storage of the input + # as a column-major matrix. In later iteration of the matrix loop + # this reinterprets the output of the previous iteration. + inp = buffer.get_temporary("buff_step{}_in".format(l), + shape=shape + vec_shape, + dim_tags=ftags, + ) + + # The input temporary will only be read from, so we need to silence + # the loopy warning + silenced_warning('read_no_write({})'.format(inp)) + + return prim.Subscript(prim.Variable(inp), inames + vec_iname) + def realize_direct_input(self, shape, inames): # TODO: vector_cost_permutation not used! @@ -222,23 +268,6 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase): # need to load scalars into the SIMD vector. raise NotImplementedError("SIMD loads from scalars not implemented!") - def realize_input(self, shape, inames, vec_shape, vec_iname, buffer, ftags, l): - # TODO: vector_cost_permutation not used! - - # Get a temporary that interprets the base storage of the input - # as a column-major matrix. In later iteration of the matrix loop - # this reinterprets the output of the previous iteration. - inp = buffer.get_temporary("buff_step{}_in".format(l), - shape=shape + vec_shape, - dim_tags=ftags, - ) - - # The input temporary will only be read from, so we need to silence - # the loopy warning - silenced_warning('read_no_write({})'.format(inp)) - - return prim.Subscript(prim.Variable(inp), inames + vec_iname) - @property def function_args(self): return sum((i.function_args for i in remove_duplicates(self.interfaces)), ()) @@ -262,11 +291,15 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase): class VectorSumfactKernelOutput(SumfactKernelInterfaceBase): def __init__(self, interfaces, perm): self.interfaces = interfaces - self.vector_cost_permutation = perm + self._cost_permutation = perm def __repr__(self): return "_".join(repr(o) for o in self.interfaces) + @property + def cost_permutation(self): + return self._cost_permutation + @property def quadrature_permutation(self): # TODO: For now we assure that all kerneles have the same quadrature_permutation @@ -293,29 +326,29 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase): return prim.Call(prim.Variable(hadd_function), (result,)) - def setup_output(self, sf, result, insn_dep): - # TODO: vector_cost_permutation not used! - - outputs = set(self.interfaces) + def realize_input(self, inames, shape, vec_iname, vec_shape, buffer, ftags, l): + # TODO: Include permutations of scalar kernels as soon as they could be different + shape = permute_backward(shape, self.cost_permutation) + inames = permute_backward(inames, self.cost_permutation) - trial_element, = set(o.trial_element for o in self.interfaces) - trial_element_index = set(o.trial_element_index for o in self.interfaces).pop() - from dune.codegen.sumfact.accumulation import accum_iname - element = get_leaf(trial_element, trial_element_index) if trial_element is not None else None - inames = tuple(accum_iname(element, mat.rows, i) - for i, mat in enumerate(sf.matrix_sequence_quadrature_permuted)) - veciname = accum_iname(element, sf.vector_width // len(outputs), "vec") - transform(lp.tag_inames, [(veciname, "vec")]) + # Get a temporary that interprets the base storage of the input + # as a column-major matrix. In later iteration of the matrix loop + # this reinterprets the output of the previous iteration. + inp = buffer.get_temporary("buff_step{}_in".format(l), + shape=shape + vec_shape, + dim_tags=ftags, + ) - deps = frozenset() - for o in outputs: - hadd_result = self._add_hadd(o, maybe_wrap_subscript(result, tuple(prim.Variable(iname) for iname in inames + (veciname,)))) - deps = deps.union(o.setup_output(sf, hadd_result, insn_dep, inames=inames, additional_inames=(veciname,))) + # The input temporary will only be read from, so we need to silence + # the loopy warning + silenced_warning('read_no_write({})'.format(inp)) - return deps + return prim.Subscript(prim.Variable(inp), inames + vec_iname) def realize_direct_output(self, result, inames, shape, **args): - # TODO: vector_cost_permutation not used! + # TODO: Find out what needs to happen here + # inames = permute_backward(inames, self.cost_permutation) + # shape = permute_backward(shape, self.cost_permutation) outputs = set(self.interfaces) @@ -335,24 +368,26 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase): return deps - def realize_input(self, shape, inames, vec_shape, vec_iname, buffer, ftags, l): - # TODO: Include permutations of scalar kernels as soon as they could be different - shape = permute_backward(shape, self.vector_cost_permutation) - inames = permute_backward(inames, self.vector_cost_permutation) + def setup_output(self, sf, result, insn_dep): + # TODO: vector_cost_permutation not used! - # Get a temporary that interprets the base storage of the input - # as a column-major matrix. In later iteration of the matrix loop - # this reinterprets the output of the previous iteration. - inp = buffer.get_temporary("buff_step{}_in".format(l), - shape=shape + vec_shape, - dim_tags=ftags, - ) + outputs = set(self.interfaces) - # The input temporary will only be read from, so we need to silence - # the loopy warning - silenced_warning('read_no_write({})'.format(inp)) + trial_element, = set(o.trial_element for o in self.interfaces) + trial_element_index = set(o.trial_element_index for o in self.interfaces).pop() + from dune.codegen.sumfact.accumulation import accum_iname + element = get_leaf(trial_element, trial_element_index) if trial_element is not None else None + inames = tuple(accum_iname(element, mat.rows, i) + for i, mat in enumerate(sf.matrix_sequence_quadrature_permuted)) + veciname = accum_iname(element, sf.vector_width // len(outputs), "vec") + transform(lp.tag_inames, [(veciname, "vec")]) - return prim.Subscript(prim.Variable(inp), inames + vec_iname) + deps = frozenset() + for o in outputs: + hadd_result = self._add_hadd(o, maybe_wrap_subscript(result, tuple(prim.Variable(iname) for iname in inames + (veciname,)))) + deps = deps.union(o.setup_output(sf, hadd_result, insn_dep, inames=inames, additional_inames=(veciname,))) + + return deps @property def function_args(self): -- GitLab