From 38072649ec66d73506470e2bcef9f50be6acc121 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ren=C3=A9=20He=C3=9F?= <rene.hess@iwr.uni-heidelberg.de>
Date: Fri, 23 Nov 2018 09:10:24 +0100
Subject: [PATCH] [skip ci] Improve sumfact kernel interface

Introduce different methods for realize_input/output
realize_direct_input/output and setup_input/output. The setup methods cover
code generation outside the sumfact kernel function (creating input array or
accumulating result). realize and realize_direct handle the input/output in the
nonfastdg and fastdg code branch.

Seperate interface methods make it a lot easier to find out where each of those
methods will be applied. Besides that most interface classes need to provide
more that two of those methods anyway...
---
 python/dune/codegen/sumfact/accumulation.py |  41 +++---
 python/dune/codegen/sumfact/basis.py        |  36 ++---
 python/dune/codegen/sumfact/geometry.py     |   2 +-
 python/dune/codegen/sumfact/realization.py  |  41 ++----
 python/dune/codegen/sumfact/symbolic.py     | 155 ++++++++++++--------
 5 files changed, 149 insertions(+), 126 deletions(-)

diff --git a/python/dune/codegen/sumfact/accumulation.py b/python/dune/codegen/sumfact/accumulation.py
index b43d7906..eb44e3cb 100644
--- a/python/dune/codegen/sumfact/accumulation.py
+++ b/python/dune/codegen/sumfact/accumulation.py
@@ -154,6 +154,25 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord):
             from dune.codegen.sumfact.basis import lfs_inames
             return lfs_inames(get_leaf(self.trial_element, self.trial_element_index), self.restriction)
 
+    def realize_input(self, inames, shape, vec_iname, vec_shape, buffer, ftags, l):
+        # TODO: This should happen in stage 2 and not in stage 3
+        shape = permute_backward(shape, self.cost_permutation)
+        inames = permute_backward(inames, self.cost_permutation)
+
+        # Get a temporary that interprets the base storage of the input
+        # as a column-major matrix. In later iteration of the matrix loop
+        # this reinterprets the output of the previous iteration.
+        inp = buffer.get_temporary("buff_step{}_in".format(l),
+                                   shape=shape + vec_shape,
+                                   dim_tags=ftags,
+                                   )
+
+        # The input temporary will only be read from, so we need to silence
+        # the loopy warning
+        silenced_warning('read_no_write({})'.format(inp))
+
+        return prim.Subscript(prim.Variable(inp), inames + vec_iname)
+
     def setup_output(self, sf, result, insn_dep, inames=None, additional_inames=()):
         trial_leaf_element = get_leaf(self.trial_element, self.trial_element_index) if self.trial_element is not None else None
 
@@ -213,6 +232,9 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord):
         return frozenset({dep})
 
     def realize_direct_output(self, result, inames, shape, which=0, **args):
+        inames = permute_backward(inames, self.cost_permutation)
+        inames = permute_backward(inames, self.quadrature_permutation)
+
         direct_output = "fastdg{}".format(which)
         ftags = ",".join(["f"] * len(shape))
 
@@ -241,25 +263,6 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord):
                                       tags=frozenset({"sumfact_stage3"}),
                                       **args)})
 
-    def realize_input(self, shape, inames, vec_shape, vec_iname, buffer, ftags, l):
-        # TODO: This should happen in stage 2 and not in stage 3
-        shape = permute_backward(shape, self.cost_permutation)
-        inames = permute_backward(inames, self.cost_permutation)
-
-        # Get a temporary that interprets the base storage of the input
-        # as a column-major matrix. In later iteration of the matrix loop
-        # this reinterprets the output of the previous iteration.
-        inp = buffer.get_temporary("buff_step{}_in".format(l),
-                                   shape=shape + vec_shape,
-                                   dim_tags=ftags,
-                                   )
-
-        # The input temporary will only be read from, so we need to silence
-        # the loopy warning
-        silenced_warning('read_no_write({})'.format(inp))
-
-        return prim.Subscript(prim.Variable(inp), inames + vec_iname)
-
     @property
     def function_name_suffix(self):
         if get_form_option("fastdg"):
diff --git a/python/dune/codegen/sumfact/basis.py b/python/dune/codegen/sumfact/basis.py
index 78bf364b..a856c5a9 100644
--- a/python/dune/codegen/sumfact/basis.py
+++ b/python/dune/codegen/sumfact/basis.py
@@ -161,6 +161,24 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord):
 
         return insn_dep.union(frozenset({insn}))
 
+    def realize_input(self, inames, shape, vec_iname, vec_shape, buffer, ftags, l):
+        # Note: Here we do not need to reverse any permutation since this is
+        # already done in the setup_input method above!
+
+        # Get a temporary that interprets the base storage of the input
+        # as a column-major matrix. In later iteration of the matrix loop
+        # this reinterprets the output of the previous iteration.
+        inp = buffer.get_temporary("buff_step{}_in".format(l),
+                                   shape=shape + vec_shape,
+                                   dim_tags=ftags,
+                                   )
+
+        # The input temporary will only be read from, so we need to silence
+        # the loopy warning
+        silenced_warning('read_no_write({})'.format(inp))
+
+        return prim.Subscript(prim.Variable(inp), inames + vec_iname)
+
     def realize_direct_input(self, shape, inames, which=0):
         # If the input comes directly from a global data structure inames are
         # ordered x,y,z,...
@@ -183,24 +201,6 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord):
 
         return prim.Subscript(prim.Variable(arg), inames)
 
-    def realize_input(self, shape, inames, vec_shape, vec_iname, buffer, ftags, l):
-        # Note: Here we do not need to reverse any permutation since this is
-        # already done in the setup_input method above!
-
-        # Get a temporary that interprets the base storage of the input
-        # as a column-major matrix. In later iteration of the matrix loop
-        # this reinterprets the output of the previous iteration.
-        inp = buffer.get_temporary("buff_step{}_in".format(l),
-                                   shape=shape + vec_shape,
-                                   dim_tags=ftags,
-                                   )
-
-        # The input temporary will only be read from, so we need to silence
-        # the loopy warning
-        silenced_warning('read_no_write({})'.format(inp))
-
-        return prim.Subscript(prim.Variable(inp), inames + vec_iname)
-
     @property
     def function_name_suffix(self):
         if get_form_option("fastdg"):
diff --git a/python/dune/codegen/sumfact/geometry.py b/python/dune/codegen/sumfact/geometry.py
index 86528227..8b714a41 100644
--- a/python/dune/codegen/sumfact/geometry.py
+++ b/python/dune/codegen/sumfact/geometry.py
@@ -172,7 +172,7 @@ class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableRecord):
 
         return insn_dep.union(frozenset({insn}))
 
-    def realize_input(self, shape, inames, vec_shape, vec_iname, buffer, ftags, l):
+    def realize_input(self, inames, shape, vec_iname, vec_shape, buffer, ftags, l):
         # Get a temporary that interprets the base storage of the input
         # as a column-major matrix. In later iteration of the matrix loop
         # this reinterprets the output of the previous iteration.
diff --git a/python/dune/codegen/sumfact/realization.py b/python/dune/codegen/sumfact/realization.py
index 73adc444..d1f10d6f 100644
--- a/python/dune/codegen/sumfact/realization.py
+++ b/python/dune/codegen/sumfact/realization.py
@@ -195,10 +195,10 @@ def realize_sumfact_kernel_function(sf):
             input_summand = sf.interface.realize_direct_input(inp_shape, input_inames)
         elif l == 0:
             # TODO: Simplify arguments!
-            input_summand = sf.interface.realize_input(inp_shape,
-                                                       input_inames,
-                                                       vec_shape,
+            input_summand = sf.interface.realize_input(input_inames,
+                                                       inp_shape,
                                                        vec_iname,
+                                                       vec_shape,
                                                        buffer,
                                                        ftags,
                                                        l,
@@ -238,36 +238,21 @@ def realize_sumfact_kernel_function(sf):
         # In case of direct output we directly accumulate the result
         # of the Sumfactorization into some global data structure.
         if l == len(matrix_sequence) - 1 and get_form_option('fastdg') and sf.stage == 3:
-            # TODO: Move permutations to interface!
-            output_inames = permute_backward(output_inames, sf.cost_permutation)
-            output_inames = permute_backward(output_inames, sf.interface.quadrature_permutation)
-
             if sf.vectorized:
                 insn_args["forced_iname_deps"] = frozenset({vec_iname[0].name})
             insn_dep = sf.interface.realize_direct_output(matprod, output_inames, out_shape, **insn_args)
         elif l == len(matrix_sequence) - 1:
-            # TODO: Move permutations to interface!
-            output_inames = permute_backward(output_inames, sf.cost_permutation)
             output_shape = tuple(out_shape[1:]) + (out_shape[0],)
-
-            # TODO: Move permutations to interface
-            output_shape = permute_backward(output_shape, sf.cost_permutation)
-            if sf.stage == 3:
-                output_inames = permute_backward(output_inames, sf.interface.quadrature_permutation)
-                output_shape = permute_backward(output_shape, sf.interface.quadrature_permutation)
-
-            out = buffer.get_temporary("buff_step{}_out".format(l),
-                                       shape=output_shape + vec_shape,
-                                       dim_tags=ftags,
-                                       )
-
-            # Issue the reduction instruction that implements the multiplication
-            # at the same time store the instruction ID for the next instruction to depend on
-            insn_dep = frozenset({instruction(assignee=prim.Subscript(prim.Variable(out), output_inames + vec_iname),
-                                              expression=matprod,
-                                              **insn_args
-                                              )
-                                  })
+            insn_dep = sf.interface.realize_output(matprod,
+                                                   output_inames,
+                                                   output_shape,
+                                                   vec_iname,
+                                                   vec_shape,
+                                                   buffer,
+                                                   ftags,
+                                                   l,
+                                                   **insn_args,
+                                                   )
         else:
             output_shape = tuple(out_shape[1:]) + (out_shape[0],)
             out = buffer.get_temporary("buff_step{}_out".format(l),
diff --git a/python/dune/codegen/sumfact/symbolic.py b/python/dune/codegen/sumfact/symbolic.py
index 440de43a..8c7052ad 100644
--- a/python/dune/codegen/sumfact/symbolic.py
+++ b/python/dune/codegen/sumfact/symbolic.py
@@ -2,6 +2,7 @@
 
 from dune.codegen.options import get_form_option, get_option
 from dune.codegen.generation import (get_counted_variable,
+                                     instruction,
                                      silenced_warning,
                                      subst_rule,
                                      transform,
@@ -43,12 +44,12 @@ class SumfactKernelInterfaceBase(object):
         """
         raise NotImplementedError
 
-    def realize_input(self, shape, inames, vec_shape, vec_iname, buffer, ftags, l):
+    def realize_input(self, inames, shape, vec_iname, vec_shape, buffer, ftags, l):
         """Interpret the input of sumfact kernel function in the right way (non fastdgg)
 
         This happens inside the sumfact kernel function.
 
-        TODO: Cleanup input
+        TODO: Cleanup arguments
         TODO: Add note about permutation
         TODO: Document input arguments
         """
@@ -64,24 +65,52 @@ class SumfactKernelInterfaceBase(object):
         """
         raise NotImplementedError
 
-    def realize_direct_output(self, result, iname, shape, which=0, **args):
-        """Accumulate results directly in the sumfact kernel function (fastdg)
+    def setup_output(self, sf, result, insn_dep, inames=None, additional_inames=()):
+        """Generate accumulate instruction after sumfact kernel function (non fastdg)
+
+        This happens after the function call.
+
+        TODO: Add note about permutation
+        TODO: Document input arguments
+        """
+        raise NotImplementedError
+
+    def realize_output(self, result, inames, shape, vec_iname, vec_shape, buffer, ftags, l, **args):
+        """Handle the output of the last tensor contraction in the sumfact kernel function the right way
 
         This happens inside the sumfact kernel function.
 
+        TODO: Cleanup arguments
         TODO: Add note about permutation
         TODO: Document input arguments
         """
+        inames = permute_backward(inames, self.cost_permutation)
+        shape = permute_backward(shape, self.cost_permutation)
+        if self.stage == 3:
+            inames = permute_backward(inames, self.quadrature_permutation)
+            shape = permute_backward(shape, self.quadrature_permutation)
 
-    def setup_output(self, sf, result, insn_dep, inames=None, additional_inames=()):
-        """Generate accumulate instruction after sumfact kernel function (non fastdg)
+        out = buffer.get_temporary("buff_step{}_out".format(l),
+                                   shape=shape + vec_shape,
+                                   dim_tags=ftags,
+                                   )
 
-        This happens after the function call.
+        # Issue the reduction instruction that implements the multiplication
+        # at the same time store the instruction ID for the next instruction to depend on
+        return frozenset({instruction(assignee=prim.Subscript(prim.Variable(out), inames + vec_iname),
+                                      expression=result,
+                                      **args
+                                      )
+                          })
+
+    def realize_direct_output(self, result, iname, shape, which=0, **args):
+        """Accumulate results directly in the sumfact kernel function (fastdg)
+
+        This happens inside the sumfact kernel function.
 
         TODO: Add note about permutation
         TODO: Document input arguments
         """
-        raise NotImplementedError
 
     @property
     def quadrature_permutation(self):
@@ -178,7 +207,7 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase):
         for i in self.interfaces:
             assert i.cost_permutation == cost_permutation
 
-        return vector_cost_permutation
+        return self.vector_cost_permutation
 
     @property
     def stage(self):
@@ -193,6 +222,23 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase):
             dep = dep.union(inp.setup_input(sf, dep, index=i))
         return dep
 
+    def realize_input(self, inames, shape, vec_iname, vec_shape, buffer, ftags, l):
+        # TODO: vector_cost_permutation not used!
+
+        # Get a temporary that interprets the base storage of the input
+        # as a column-major matrix. In later iteration of the matrix loop
+        # this reinterprets the output of the previous iteration.
+        inp = buffer.get_temporary("buff_step{}_in".format(l),
+                                   shape=shape + vec_shape,
+                                   dim_tags=ftags,
+                                   )
+
+        # The input temporary will only be read from, so we need to silence
+        # the loopy warning
+        silenced_warning('read_no_write({})'.format(inp))
+
+        return prim.Subscript(prim.Variable(inp), inames + vec_iname)
+
     def realize_direct_input(self, shape, inames):
         # TODO: vector_cost_permutation not used!
 
@@ -222,23 +268,6 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase):
             # need to load scalars into the SIMD vector.
             raise NotImplementedError("SIMD loads from scalars not implemented!")
 
-    def realize_input(self, shape, inames, vec_shape, vec_iname, buffer, ftags, l):
-        # TODO: vector_cost_permutation not used!
-
-        # Get a temporary that interprets the base storage of the input
-        # as a column-major matrix. In later iteration of the matrix loop
-        # this reinterprets the output of the previous iteration.
-        inp = buffer.get_temporary("buff_step{}_in".format(l),
-                                   shape=shape + vec_shape,
-                                   dim_tags=ftags,
-                                   )
-
-        # The input temporary will only be read from, so we need to silence
-        # the loopy warning
-        silenced_warning('read_no_write({})'.format(inp))
-
-        return prim.Subscript(prim.Variable(inp), inames + vec_iname)
-
     @property
     def function_args(self):
         return sum((i.function_args for i in remove_duplicates(self.interfaces)), ())
@@ -262,11 +291,15 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase):
 class VectorSumfactKernelOutput(SumfactKernelInterfaceBase):
     def __init__(self, interfaces, perm):
         self.interfaces = interfaces
-        self.vector_cost_permutation = perm
+        self._cost_permutation = perm
 
     def __repr__(self):
         return "_".join(repr(o) for o in self.interfaces)
 
+    @property
+    def cost_permutation(self):
+        return self._cost_permutation
+
     @property
     def quadrature_permutation(self):
         # TODO: For now we assure that all kerneles have the same quadrature_permutation
@@ -293,29 +326,29 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase):
 
         return prim.Call(prim.Variable(hadd_function), (result,))
 
-    def setup_output(self, sf, result, insn_dep):
-        # TODO: vector_cost_permutation not used!
-
-        outputs = set(self.interfaces)
+    def realize_input(self, inames, shape, vec_iname, vec_shape, buffer, ftags, l):
+        # TODO: Include permutations of scalar kernels as soon as they could be different
+        shape = permute_backward(shape, self.cost_permutation)
+        inames = permute_backward(inames, self.cost_permutation)
 
-        trial_element, = set(o.trial_element for o in self.interfaces)
-        trial_element_index = set(o.trial_element_index for o in self.interfaces).pop()
-        from dune.codegen.sumfact.accumulation import accum_iname
-        element = get_leaf(trial_element, trial_element_index) if trial_element is not None else None
-        inames = tuple(accum_iname(element, mat.rows, i)
-                       for i, mat in enumerate(sf.matrix_sequence_quadrature_permuted))
-        veciname = accum_iname(element, sf.vector_width // len(outputs), "vec")
-        transform(lp.tag_inames, [(veciname, "vec")])
+        # Get a temporary that interprets the base storage of the input
+        # as a column-major matrix. In later iteration of the matrix loop
+        # this reinterprets the output of the previous iteration.
+        inp = buffer.get_temporary("buff_step{}_in".format(l),
+                                   shape=shape + vec_shape,
+                                   dim_tags=ftags,
+                                   )
 
-        deps = frozenset()
-        for o in outputs:
-            hadd_result = self._add_hadd(o, maybe_wrap_subscript(result, tuple(prim.Variable(iname) for iname in inames + (veciname,))))
-            deps = deps.union(o.setup_output(sf, hadd_result, insn_dep, inames=inames, additional_inames=(veciname,)))
+        # The input temporary will only be read from, so we need to silence
+        # the loopy warning
+        silenced_warning('read_no_write({})'.format(inp))
 
-        return deps
+        return prim.Subscript(prim.Variable(inp), inames + vec_iname)
 
     def realize_direct_output(self, result, inames, shape, **args):
-        # TODO: vector_cost_permutation not used!
+        # TODO: Find out what needs to happen here
+        # inames = permute_backward(inames, self.cost_permutation)
+        # shape = permute_backward(shape, self.cost_permutation)
 
         outputs = set(self.interfaces)
 
@@ -335,24 +368,26 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase):
 
         return deps
 
-    def realize_input(self, shape, inames, vec_shape, vec_iname, buffer, ftags, l):
-        # TODO: Include permutations of scalar kernels as soon as they could be different
-        shape = permute_backward(shape, self.vector_cost_permutation)
-        inames = permute_backward(inames, self.vector_cost_permutation)
+    def setup_output(self, sf, result, insn_dep):
+        # TODO: vector_cost_permutation not used!
 
-        # Get a temporary that interprets the base storage of the input
-        # as a column-major matrix. In later iteration of the matrix loop
-        # this reinterprets the output of the previous iteration.
-        inp = buffer.get_temporary("buff_step{}_in".format(l),
-                                   shape=shape + vec_shape,
-                                   dim_tags=ftags,
-                                   )
+        outputs = set(self.interfaces)
 
-        # The input temporary will only be read from, so we need to silence
-        # the loopy warning
-        silenced_warning('read_no_write({})'.format(inp))
+        trial_element, = set(o.trial_element for o in self.interfaces)
+        trial_element_index = set(o.trial_element_index for o in self.interfaces).pop()
+        from dune.codegen.sumfact.accumulation import accum_iname
+        element = get_leaf(trial_element, trial_element_index) if trial_element is not None else None
+        inames = tuple(accum_iname(element, mat.rows, i)
+                       for i, mat in enumerate(sf.matrix_sequence_quadrature_permuted))
+        veciname = accum_iname(element, sf.vector_width // len(outputs), "vec")
+        transform(lp.tag_inames, [(veciname, "vec")])
 
-        return prim.Subscript(prim.Variable(inp), inames + vec_iname)
+        deps = frozenset()
+        for o in outputs:
+            hadd_result = self._add_hadd(o, maybe_wrap_subscript(result, tuple(prim.Variable(iname) for iname in inames + (veciname,))))
+            deps = deps.union(o.setup_output(sf, hadd_result, insn_dep, inames=inames, additional_inames=(veciname,)))
+
+        return deps
 
     @property
     def function_args(self):
-- 
GitLab