From 0001fb038cdf0800cc82aa4b147ffa357cc6c229 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Wed, 19 Apr 2017 17:00:32 +0200
Subject: [PATCH] Move assembly of input for sf kernels to a dedicated object

---
 python/dune/perftool/sumfact/accumulation.py |  9 +-
 python/dune/perftool/sumfact/basis.py        | 86 +++++++++++++++++---
 python/dune/perftool/sumfact/realization.py  | 47 +----------
 python/dune/perftool/sumfact/symbolic.py     | 39 ++++-----
 4 files changed, 101 insertions(+), 80 deletions(-)

diff --git a/python/dune/perftool/sumfact/accumulation.py b/python/dune/perftool/sumfact/accumulation.py
index d24e100c..f01d2bd3 100644
--- a/python/dune/perftool/sumfact/accumulation.py
+++ b/python/dune/perftool/sumfact/accumulation.py
@@ -29,12 +29,13 @@ from dune.perftool.sumfact.tabulation import (basis_functions_per_direction,
 from dune.perftool.sumfact.switch import (get_facedir,
                                           get_facemod,
                                           )
-from dune.perftool.sumfact.symbolic import SumfactKernel
+from dune.perftool.sumfact.symbolic import SumfactKernel, SumfactKernelInputBase
 from dune.perftool.ufl.modified_terminals import extract_modified_arguments
 from dune.perftool.tools import get_pymbolic_basename
 from dune.perftool.error import PerftoolError
 from dune.perftool.sumfact.quadrature import quadrature_inames
 
+from pytools import ImmutableRecord
 
 import loopy as lp
 import numpy as np
@@ -59,6 +60,10 @@ def accum_iname(restriction, bound, i):
     return sumfact_iname(bound, "accum")
 
 
+class AlreadyAssembledInput(SumfactKernelInputBase, ImmutableRecord):
+    pass
+
+
 @backend(interface="accum_insn", name="sumfact")
 def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
     # When doing sumfactorization we want to split the test function
@@ -126,7 +131,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
                            preferred_position=indices[-1] if accterm.new_indices else None,
                            accumvar=accum,
                            within_inames=jacobian_inames,
-                           coeff_func_index=coeff_func_index,
+                           input=AlreadyAssembledInput(index=coeff_func_index),
                            )
 
         from dune.perftool.sumfact.vectorization import attach_vectorization_info
diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py
index cb318c7e..d7687986 100644
--- a/python/dune/perftool/sumfact/basis.py
+++ b/python/dune/perftool/sumfact/basis.py
@@ -33,19 +33,77 @@ from dune.perftool.pdelab.geometry import (local_dimension,
                                            world_dimension,
                                            )
 from dune.perftool.loopy.buffer import initialize_buffer
-from dune.perftool.sumfact.symbolic import SumfactKernel
+from dune.perftool.sumfact.symbolic import SumfactKernel, SumfactKernelInputBase
 from dune.perftool.options import get_option
 from dune.perftool.pdelab.driver import FEM_name_mangling
 from dune.perftool.pdelab.restriction import restricted_name
+from dune.perftool.pdelab.spaces import name_lfs, name_lfs_bound, lfs_child, name_leaf_lfs
 from dune.perftool.tools import maybe_wrap_subscript
+from dune.perftool.pdelab.basis import shape_as_pymbolic
+from dune.perftool.sumfact.accumulation import sumfact_iname
 
-from pytools import product
+from ufl.functionview import select_subelement
+from ufl import VectorElement, TensorElement
+
+from pytools import product, ImmutableRecord
 
 from loopy.match import Writes
 
 import pymbolic.primitives as prim
 
 
+class LFSSumfactKernelInput(SumfactKernelInputBase, ImmutableRecord):
+    def __init__(self,
+                 coeff_func=None,
+                 coeff_func_index=None,
+                 element=None,
+                 component=None,
+                 restriction=0,
+                 ):
+        ImmutableRecord.__init__(self,
+                                 coeff_func=coeff_func,
+                                 coeff_func_index=coeff_func_index,
+                                 element=element,
+                                 component=component,
+                                 restriction=restriction,
+                                 )
+
+    def realize(self, sf, index, insn_dep):
+        lfs = name_lfs(self.element, self.restriction, self.component)
+        sub_element = select_subelement(self.element, self.component)
+        shape = sub_element.value_shape() + (self.element.cell().geometric_dimension(),)
+
+        if isinstance(sub_element, (VectorElement, TensorElement)):
+            # Could be 0 but shouldn't be None
+            assert self.coeff_func_index is not None
+
+            lfs_pym = lfs_child(lfs,
+                                (self.coeff_func_index,),
+                                shape=shape_as_pymbolic(shape[:-1]),
+                                symmetry=self.element.symmetry())
+
+        leaf_element = sub_element
+        if isinstance(sub_element, (VectorElement, TensorElement)):
+            leaf_element = sub_element.sub_elements()[0]
+
+        lfs = name_leaf_lfs(leaf_element, self.restriction)
+        basisiname = sumfact_iname(name_lfs_bound(lfs), "basis")
+        container = self.coeff_func(self.restriction)
+        if isinstance(sub_element, (VectorElement, TensorElement)):
+            from dune.perftool.pdelab.argument import pymbolic_coefficient as pc
+            coeff = pc(container, lfs_pym, basisiname)
+        else:
+            from dune.perftool.pdelab.argument import pymbolic_coefficient as pc
+            coeff = pc(container, lfs, basisiname)
+
+        assignee = prim.Subscript(prim.Variable("input_{}".format(sf.buffer)),
+                                  (prim.Variable(basisiname),) + (index,))
+        instruction(assignee=assignee,
+                    expression=coeff,
+                    depends_on=sf.insn_dep.union(insn_dep),
+                    tags=frozenset({"sumfact_stage{}".format(sf.stage)}),
+                    )
+
 def name_sumfact_base_buffer():
     count = get_counter('sumfact_base_buffer')
     name = "buffer_{}".format(str(count))
@@ -102,14 +160,17 @@ def pymbolic_coefficient_gradient(element, restriction, component, coeff_func, v
         if len(indices) == 2:
             coeff_func_index = indices[0]
 
+        inp = LFSSumfactKernelInput(coeff_func=coeff_func,
+                                    coeff_func_index=coeff_func_index,
+                                    element=element,
+                                    component=component,
+                                    restriction=restriction,
+                                    )
+
         # The sum factorization kernel object gathering all relevant information
         sf = SumfactKernel(matrix_sequence=matrix_sequence,
-                           restriction=restriction,
                            preferred_position=indices[-1],
-                           coeff_func=coeff_func,
-                           coeff_func_index=coeff_func_index,
-                           element=element,
-                           component=component,
+                           input=inp,
                            )
 
         from dune.perftool.sumfact.vectorization import attach_vectorization_info
@@ -156,11 +217,14 @@ def pymbolic_coefficient(element, restriction, component, coeff_func, visitor_in
                                                       facemod=get_facemod(restriction),
                                                       basis_size=basis_size)
 
+    inp = LFSSumfactKernelInput(coeff_func=coeff_func,
+                                element=element,
+                                component=component,
+                                restriction=restriction,
+                                )
+
     sf = SumfactKernel(matrix_sequence=matrix_sequence,
-                       restriction=restriction,
-                       coeff_func=coeff_func,
-                       element=element,
-                       component=component,
+                       input=inp,
                        )
 
     from dune.perftool.sumfact.vectorization import attach_vectorization_info
diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py
index 4ff4dfb4..c03c5321 100644
--- a/python/dune/perftool/sumfact/realization.py
+++ b/python/dune/perftool/sumfact/realization.py
@@ -2,9 +2,6 @@
 The code that triggers the creation of the necessary code constructs
 to realize a sum factorization kernel
 """
-from ufl.functionview import select_subelement
-from ufl import VectorElement, TensorElement
-
 from dune.perftool.generation import (barrier,
                                       dump_accumulate_timer,
                                       generator_factory,
@@ -21,7 +18,6 @@ from dune.perftool.loopy.buffer import (get_buffer_temporary,
 from dune.perftool.pdelab.argument import pymbolic_coefficient
 from dune.perftool.pdelab.basis import shape_as_pymbolic
 from dune.perftool.pdelab.geometry import world_dimension
-from dune.perftool.pdelab.spaces import name_lfs, name_lfs_bound, lfs_child, name_leaf_lfs
 from dune.perftool.options import get_option
 from dune.perftool.pdelab.signatures import assembler_routine_name
 from dune.perftool.sumfact.permutation import (sumfact_permutation_strategy,
@@ -63,7 +59,7 @@ def _realize_sum_factorization_kernel(sf):
 
     # Set up the input for stage 1
     if sf.stage == 1 and not get_option("fastdg"):
-        assert sf.coeff_func
+        assert sf.input
 
         # Get the input temporary!
         input_setup = get_buffer_temporary(sf.buffer,
@@ -71,53 +67,18 @@ def _realize_sum_factorization_kernel(sf):
                                            name="input_{}".format(sf.buffer)
                                            )
 
-        def _write_input(inputsf, index=0):
-            # Write initial coefficients into buffer
-            lfs = name_lfs(inputsf.element, inputsf.restriction, inputsf.component)
-
-            sub_element = select_subelement(inputsf.element, inputsf.component)
-            shape = sub_element.value_shape() + (inputsf.element.cell().geometric_dimension(),)
-
-            if isinstance(sub_element, (VectorElement, TensorElement)):
-                # Could be 0 but shouldn't be None
-                assert inputsf.coeff_func_index is not None
-
-                lfs_pym = lfs_child(lfs,
-                                    (inputsf.coeff_func_index,),
-                                    shape=shape_as_pymbolic(shape[:-1]),
-                                    symmetry=inputsf.element.symmetry())
-
-            leaf_element = sub_element
-            if isinstance(sub_element, (VectorElement, TensorElement)):
-                leaf_element = sub_element.sub_elements()[0]
-
-            lfs = name_leaf_lfs(leaf_element, inputsf.restriction)
-
-            basisiname = sumfact_iname(name_lfs_bound(lfs), "basis")
-            container = inputsf.coeff_func(inputsf.restriction)
-            if isinstance(sub_element, (VectorElement, TensorElement)):
-                coeff = pymbolic_coefficient(container, lfs_pym, basisiname)
-            else:
-                coeff = pymbolic_coefficient(container, lfs, basisiname)
-            assignee = prim.Subscript(prim.Variable(input_setup), (prim.Variable(basisiname),) + (index,))
-            instruction(assignee=assignee,
-                        expression=coeff,
-                        depends_on=inputsf.insn_dep.union(insn_dep),
-                        tags=frozenset({"sumfact_stage{}".format(sf.stage)}),
-                        )
-
         if sf.vectorized:
             for i, inputsf in enumerate(sf.kernels):
-                _write_input(inputsf, i)
+                inputsf.input.realize(inputsf, i, insn_dep)
         else:
-            _write_input(sf)
+            sf.input.realize(sf, 0, insn_dep)
 
         insn_dep = insn_dep.union(frozenset({lp.match.Writes("input_{}".format(sf.buffer))}))
 
     # Construct the direct_input for the FastDG case
     direct_input = None
     if get_option('fastdg') and sf.stage == 1:
-        direct_input = sf.coeff_func(sf.restriction)
+        direct_input = sf.input.coeff_func(sf.input.restriction)
 
     direct_output = None
     if get_option('fastdg') and sf.stage == 3:
diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py
index b254d5c6..88f884c8 100644
--- a/python/dune/perftool/sumfact/symbolic.py
+++ b/python/dune/perftool/sumfact/symbolic.py
@@ -14,6 +14,10 @@ import frozendict
 import inspect
 
 
+class SumfactKernelInputBase(object):
+    pass
+
+
 class SumfactKernelBase(object):
     pass
 
@@ -24,13 +28,10 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
                  buffer=None,
                  stage=1,
                  preferred_position=None,
-                 restriction=0,
+                 restriction=None,
                  within_inames=(),
                  insn_dep=frozenset(),
-                 coeff_func=None,
-                 coeff_func_index=None,
-                 element=None,
-                 component=None,
+                 input=None,
                  accumvar=None,
                  ):
         """Create a sum factorization kernel
@@ -101,6 +102,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
         element: The UFL element
         component: The treepath to the correct component of above element
         accumvar: The accumulation variable to accumulate into
+        input: A SumfactKernelInputBase instance describing the input of the kernel (required if stage == 1)
         """
         # Assert the inputs!
         assert isinstance(matrix_sequence, tuple)
@@ -112,6 +114,10 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
         if preferred_position is not None:
             assert isinstance(preferred_position, int)
 
+        if stage == 1:
+            assert isinstance(input, SumfactKernelInputBase)
+            restriction = input.restriction
+
         if stage == 3:
             assert isinstance(restriction, tuple)
 
@@ -161,7 +167,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
         work on the same input coefficient (and are suitable for simultaneous
         treatment because of that)
         """
-        return (self.restriction, self.stage, self.coeff_func, self.coeff_func_index, self.element, self.component, self.accumvar)
+        return (self.input, self.restriction, self.accumvar)
 
     #
     # Some convenience methods to extract information about the sum factorization kernel
@@ -380,24 +386,9 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
         return self.kernels[0].within_inames
 
     @property
-    def coeff_func(self):
-        assert len(set(k.coeff_func for k in self.kernels)) == 1
-        return self.kernels[0].coeff_func
-
-    @property
-    def coeff_func_index(self):
-        assert len(set(k.coeff_func_index for k in self.kernels)) == 1
-        return self.kernels[0].coeff_func_index
-
-    @property
-    def element(self):
-        assert len(set(k.element for k in self.kernels)) == 1
-        return self.kernels[0].element
-
-    @property
-    def component(self):
-        assert len(set(k.component for k in self.kernels)) == 1
-        return self.kernels[0].component
+    def input(self):
+        assert len(set(k.input for k in self.kernels)) == 1
+        return self.kernels[0].input
 
     @property
     def accumvar(self):
-- 
GitLab