From 8d2b63838311597d409ebadc1db2b48374204ce4 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Thu, 15 Feb 2018 13:50:48 +0100
Subject: [PATCH] Move accumulation code onto output object

---
 python/dune/perftool/sumfact/accumulation.py | 117 ++++++++++---------
 python/dune/perftool/sumfact/symbolic.py     |  28 ++++-
 python/dune/perftool/tools.py                |  11 ++
 3 files changed, 101 insertions(+), 55 deletions(-)

diff --git a/python/dune/perftool/sumfact/accumulation.py b/python/dune/perftool/sumfact/accumulation.py
index a01fb8ea..0b189892 100644
--- a/python/dune/perftool/sumfact/accumulation.py
+++ b/python/dune/perftool/sumfact/accumulation.py
@@ -15,6 +15,7 @@ from dune.perftool.generation import (backend,
                                       kernel_cached,
                                       temporary_variable,
                                       transform,
+                                      valuearg
                                       )
 from dune.perftool.options import (get_form_option,
                                    get_option,
@@ -26,6 +27,7 @@ from dune.perftool.pdelab.localoperator import determine_accumulation_space
 from dune.perftool.pdelab.restriction import restricted_name
 from dune.perftool.pdelab.signatures import assembler_routine_name
 from dune.perftool.pdelab.geometry import world_dimension
+from dune.perftool.pdelab.spaces import name_lfs
 from dune.perftool.sumfact.tabulation import (basis_functions_per_direction,
                                               construct_basis_matrix_sequence,
                                               )
@@ -34,7 +36,7 @@ from dune.perftool.sumfact.switch import (get_facedir,
                                           )
 from dune.perftool.sumfact.symbolic import SumfactKernel, SumfactKernelOutputBase
 from dune.perftool.ufl.modified_terminals import extract_modified_arguments
-from dune.perftool.tools import get_pymbolic_basename
+from dune.perftool.tools import get_pymbolic_basename, get_leaf
 from dune.perftool.error import PerftoolError
 from dune.perftool.sumfact.quadrature import quadrature_inames
 
@@ -90,7 +92,7 @@ class AccumulationOutput(SumfactKernelOutputBase, ImmutableRecord):
                  trial_element=None,
                  trial_element_index=None,
                  ):
-        #TODO: Isnt accumvar superfluous in the presence of all the other infos?
+        # TODO: Isnt accumvar superfluous in the presence of all the other infos?
         ImmutableRecord.__init__(self,
                                  accumvar=accumvar,
                                  restriction=None,
@@ -106,10 +108,65 @@ class AccumulationOutput(SumfactKernelOutputBase, ImmutableRecord):
             return ()
         else:
             from dune.perftool.sumfact.basis import lfs_inames
-            element = self.trial_element
-            if isinstance(element, MixedElement):
-                element = element.extract_component(self.trial_element_index)[1]
-            return lfs_inames(element, self.restriction)
+            return lfs_inames(get_leaf(self.trial_element, self.trial_element_index), self.restriction)
+
+
+    def realize(self, sf, result, insn_dep, inames=None, additional_inames=()):
+        trial_leaf_element = get_leaf(self.trial_element, self.trial_element_index) if self.trial_element is not None else None
+
+        basis_size = tuple(mat.basis_size for mat in sf.matrix_sequence)
+
+        if inames is None:
+            inames = tuple(accum_iname(trial_leaf_element, mat.rows, i)
+                           for i, mat in enumerate(sf.matrix_sequence))
+
+            # Determine the expression to accumulate with. This depends on the vectorization strategy!
+            from dune.perftool.tools import maybe_wrap_subscript
+            result = maybe_wrap_subscript(result, tuple(prim.Variable(i) for i in inames))
+
+        # Collect the lfs and lfs indices for the accumulate call
+        restriction = (0, 0) if self.restriction is None else self.restriction
+        test_lfs = name_lfs(self.test_element, restriction[0], self.test_element_index)
+        valuearg(test_lfs, dtype=lp.types.NumpyType("str"))
+        test_lfs_index = flatten_index(tuple(prim.Variable(i) for i in inames),
+                                       basis_size,
+                                       order="f"
+                                       )
+
+        accum_args = [prim.Variable(test_lfs), test_lfs_index]
+
+        # In the jacobian case, also determine the space for the ansatz space
+        if sf.within_inames:
+            # TODO the next line should get its inames from
+            # elsewhere. This is *NOT* robust (but works right now)
+            ansatz_lfs = name_lfs(self.trial_element, restriction[1], self.trial_element_index)
+            valuearg(ansatz_lfs, dtype=lp.types.NumpyType("str"))
+            from dune.perftool.sumfact.basis import _basis_functions_per_direction
+            ansatz_lfs_index = flatten_index(tuple(prim.Variable(sf.within_inames[i])
+                                                   for i in range(world_dimension())),
+                                             _basis_functions_per_direction(trial_leaf_element),
+                                             order="f"
+                                             )
+
+            accum_args.append(prim.Variable(ansatz_lfs))
+            accum_args.append(ansatz_lfs_index)
+
+        accum_args.append(result)
+
+        if not get_form_option("fastdg"):
+            rank = 2 if self.within_inames else 1
+            expr = prim.Call(PDELabAccumulationFunction(self.accumvar, rank),
+                             tuple(accum_args)
+                             )
+            instruction(assignees=(),
+                        expression=expr,
+                        forced_iname_deps=frozenset(inames + additional_inames + self.within_inames),
+                        forced_iname_deps_is_final=True,
+                        depends_on=insn_dep,
+                        predicates=sf.predicates
+                        )
+
+        return frozenset()
 
 
 class SumfactAccumulationInfo(ImmutableRecord):
@@ -358,56 +415,10 @@ def generate_accumulation_instruction(expr, visitor):
                                           depends_on=insn_dep,
                                           within_inames=frozenset(jacobian_inames))})
 
-    inames = tuple(accum_iname(trial_leaf_element, mat.rows, i)
-                   for i, mat in enumerate(vsf.matrix_sequence))
-
-    # Collect the lfs and lfs indices for the accumulate call
-    test_lfs.index = flatten_index(tuple(prim.Variable(i) for i in inames),
-                                   basis_size,
-                                   order="f"
-                                   )
-
-    # In the jacobian case, also determine the space for the ansatz space
-    if jacobian_inames:
-        # TODO the next line should get its inames from
-        # elsewhere. This is *NOT* robust (but works right now)
-        from dune.perftool.sumfact.basis import _basis_functions_per_direction
-        ansatz_lfs.index = flatten_index(tuple(prim.Variable(jacobian_inames[i])
-                                               for i in range(world_dimension())),
-                                         _basis_functions_per_direction(trial_leaf_element),
-                                         order="f"
-                                         )
-
     # Add a sum factorization kernel that implements the multiplication
     # with the test function (stage 3)
     from dune.perftool.sumfact.realization import realize_sum_factorization_kernel
     result, insn_dep = realize_sum_factorization_kernel(vsf.copy(insn_dep=vsf.insn_dep.union(insn_dep)))
 
-    # Determine the expression to accumulate with. This depends on the vectorization strategy!
-    result = prim.Subscript(result, tuple(prim.Variable(i) for i in inames))
-    vecinames = ()
-
-    if vsf.vectorized:
-        iname = accum_iname(trial_leaf_element, vsf.vector_width, "vec")
-        vecinames = (iname,)
-        transform(lp.tag_inames, [(iname, "vec")])
-        from dune.perftool.tools import maybe_wrap_subscript
-        result = prim.Call(prim.Variable("horizontal_add"),
-                           (maybe_wrap_subscript(result, prim.Variable(iname)),),
-                           )
-
     if not get_form_option("fastdg"):
-        rank = 2 if jacobian_inames else 1
-        expr = prim.Call(PDELabAccumulationFunction(accumvar, rank),
-                         (test_lfs.get_args() +
-                          ansatz_lfs.get_args() +
-                          (result,)
-                          )
-                         )
-        instruction(assignees=(),
-                    expression=expr,
-                    forced_iname_deps=frozenset(inames + vecinames + jacobian_inames),
-                    forced_iname_deps_is_final=True,
-                    depends_on=insn_dep,
-                    predicates=predicates
-                    )
+        vsf.output.realize(vsf, result, insn_dep)
diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py
index 322806fa..2b9c2e21 100644
--- a/python/dune/perftool/sumfact/symbolic.py
+++ b/python/dune/perftool/sumfact/symbolic.py
@@ -1,12 +1,15 @@
 """ A pymbolic node representing a sum factorization kernel """
 
 from dune.perftool.options import get_option
-from dune.perftool.generation import get_counted_variable
+from dune.perftool.generation import (get_counted_variable,
+                                      transform,
+                                      )
 from dune.perftool.pdelab.geometry import local_dimension, world_dimension
 from dune.perftool.sumfact.quadrature import quadrature_inames
 from dune.perftool.sumfact.tabulation import BasisTabulationMatrixBase, BasisTabulationMatrixArray
 from dune.perftool.loopy.target import dtype_floatingpoint
 from dune.perftool.loopy.vcl import ExplicitVCLCast, VCLLowerUpperLoad
+from dune.perftool.tools import get_leaf
 
 from pytools import ImmutableRecord, product
 
@@ -81,7 +84,7 @@ class SumfactKernelOutputBase(object):
     def within_inames(self):
         return ()
 
-    def realize(self, sf, dep):
+    def realize(self, sf, result, insn_dep):
         return dep
 
     def realize_direct(self):
@@ -92,6 +95,27 @@ class VectorSumfactKernelOutput(SumfactKernelOutputBase):
     def __init__(self, outputs):
         self.outputs = outputs
 
+    def realize(self, sf, result, insn_dep):
+        outputs = set(self.outputs)
+        assert(len(outputs) == 1)
+
+        o, = outputs
+
+        from dune.perftool.sumfact.accumulation import accum_iname
+        element = get_leaf(o.trial_element, o.trial_element_index) if o.trial_element is not None else None
+        inames = tuple(accum_iname(element, mat.rows, i)
+                       for i, mat in enumerate(sf.matrix_sequence))
+
+        veciname = accum_iname(element, sf.vector_width, "vec")
+        transform(lp.tag_inames, [(veciname, "vec")])
+
+        from dune.perftool.tools import maybe_wrap_subscript
+        result = prim.Call(prim.Variable("horizontal_add"),
+                           (maybe_wrap_subscript(result, tuple(prim.Variable(iname) for iname in inames + (veciname,))),),
+                           )
+
+        return o.realize(sf, result, insn_dep, inames=inames, additional_inames=(veciname,))
+
 
 class SumfactKernelBase(object):
     pass
diff --git a/python/dune/perftool/tools.py b/python/dune/perftool/tools.py
index e302f28c..b29ebe22 100644
--- a/python/dune/perftool/tools.py
+++ b/python/dune/perftool/tools.py
@@ -82,3 +82,14 @@ def list_diff(l1, l2):
             if item not in l2:
                 l.append(item)
         return l
+
+
+def get_leaf(element, index):
+    """ return a leaf element if the given element is a MixedElement """
+    leaf_element = element
+    from ufl import MixedElement
+    if isinstance(element, MixedElement):
+        assert isinstance(index, int)
+        leaf_element = element.extract_component(index)[1]
+
+    return leaf_element
-- 
GitLab