From 64d89b01759216dfd6f5d86d874788c57646c3a2 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Wed, 2 Nov 2016 17:22:46 +0100
Subject: [PATCH] Correctly accumulate sumfactorized residuals

---
 python/dune/perftool/generation/loopy.py   |  9 +++++-
 python/dune/perftool/loopy/target.py       | 16 +++++++++--
 python/dune/perftool/pdelab/spaces.py      |  2 +-
 python/dune/perftool/sumfact/quadrature.py |  9 +++---
 python/dune/perftool/sumfact/sumfact.py    | 32 ++++++++++++----------
 5 files changed, 45 insertions(+), 23 deletions(-)

diff --git a/python/dune/perftool/generation/loopy.py b/python/dune/perftool/generation/loopy.py
index 8afffba8..2b228a01 100644
--- a/python/dune/perftool/generation/loopy.py
+++ b/python/dune/perftool/generation/loopy.py
@@ -15,13 +15,20 @@ function_mangler = generator_factory(item_tags=("mangler",))
 silenced_warning = generator_factory(item_tags=("silenced_warning",), no_deco=True)
 
 
+class AccumulationGlobalArg(loopy.GlobalArg):
+    allowed_extra_kwargs = loopy.GlobalArg.allowed_extra_kwargs + ['transform']
+
+
 @generator_factory(item_tags=("argument", "globalarg"),
                    cache_key_generator=lambda n, **kw: n)
 def globalarg(name, shape=loopy.auto, **kw):
     if isinstance(shape, str):
         shape = (shape,)
     dtype = kw.pop("dtype", numpy.float64)
-    return loopy.GlobalArg(name, dtype=dtype, shape=shape, **kw)
+    if 'transform' in kw:
+        return AccumulationGlobalArg(name, dtype=dtype, shape=shape, **kw)
+    else:
+        return loopy.GlobalArg(name, dtype=dtype, shape=shape, **kw)
 
 
 @generator_factory(item_tags=("argument", "constantarg"),
diff --git a/python/dune/perftool/loopy/target.py b/python/dune/perftool/loopy/target.py
index 5dbbad91..f9789040 100644
--- a/python/dune/perftool/loopy/target.py
+++ b/python/dune/perftool/loopy/target.py
@@ -1,4 +1,5 @@
 from dune.perftool.loopy.temporary import DuneTemporaryVariable
+from dune.perftool.generation.loopy import AccumulationGlobalArg
 
 from loopy.target import (TargetBase,
                           ASTBuilderBase,
@@ -6,6 +7,10 @@ from loopy.target import (TargetBase,
                           )
 from loopy.target.c import CASTBuilder
 from loopy.target.c.codegen.expression import ExpressionToCExpressionMapper, CExpressionToCodeMapper
+from loopy.symbolic import FunctionIdentifier
+
+from pymbolic.primitives import Call, Subscript, Variable
+
 
 _registry = {'float32': 'float',
              'int32': 'int',
@@ -16,10 +21,9 @@ _registry = {'float32': 'float',
 
 class MyMapper(ExpressionToCExpressionMapper):
     def map_subscript(self, expr, type_context):
-        temporary = self.find_array(expr)
-        if isinstance(temporary, DuneTemporaryVariable) and not temporary.managed:
+        arr = self.find_array(expr)
+        if isinstance(arr, DuneTemporaryVariable) and not arr.managed:
             # If there is but one index, we do not need to handle this
-            from pymbolic.primitives import Subscript, Variable
             if isinstance(expr.index, (Variable, int)):
                 return expr
 
@@ -28,6 +32,12 @@ class MyMapper(ExpressionToCExpressionMapper):
             for i in expr.index:
                 ret = Subscript(ret, i)
             return ret
+        elif isinstance(arr, AccumulationGlobalArg):
+            assert isinstance(arr.transform, FunctionIdentifier)
+            pseudo_subscript = Subscript(expr.aggregate, expr.index)
+            flattened = ExpressionToCExpressionMapper.map_subscript(self, pseudo_subscript, type_context)
+            transformed = ExpressionToCExpressionMapper.map_call(self, Call(arr.transform, (flattened.index,)), type_context)
+            return Subscript(flattened.aggregate + '.base()', transformed)
         else:
             return ExpressionToCExpressionMapper.map_subscript(self, expr, type_context)
 
diff --git a/python/dune/perftool/pdelab/spaces.py b/python/dune/perftool/pdelab/spaces.py
index 0e7b0ce2..f03243eb 100644
--- a/python/dune/perftool/pdelab/spaces.py
+++ b/python/dune/perftool/pdelab/spaces.py
@@ -221,7 +221,7 @@ class LFSLocalIndex(FunctionIdentifier):
 
     @property
     def name(self):
-        return '{}.local_index'.format(self.lfs)
+        return '{}.localIndex'.format(self.lfs)
 
 
 @function_mangler
diff --git a/python/dune/perftool/sumfact/quadrature.py b/python/dune/perftool/sumfact/quadrature.py
index ea7104a4..408c9c0b 100644
--- a/python/dune/perftool/sumfact/quadrature.py
+++ b/python/dune/perftool/sumfact/quadrature.py
@@ -100,15 +100,16 @@ def define_quadrature_position(name):
 
 
 @backend(interface="quad_pos", name="sumfact")
-def name_quadrature_position():
+def pymbolic_quadrature_position():
     formdata = get_global_context_value('formdata')
     dim = formdata.geometric_dimension
     name = 'pos'
     temporary_variable(name, shape=(dim,), shape_impl=("fv",))
     define_quadrature_position(name)
-    return name
+    return Variable(name)
 
 
 @backend(interface="qp_in_cell", name="sumfact")
-def name_quadrature_position_in_cell(restriction):
-    return name_quadrature_position()
+def pymbolic_quadrature_position_in_cell(restriction):
+    from dune.perftool.pdelab.geometry import to_cell_coordinates
+    return to_cell_coordinates(pymbolic_quadrature_position(), restriction)
diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py
index 9a37ffd5..caca1bdc 100644
--- a/python/dune/perftool/sumfact/sumfact.py
+++ b/python/dune/perftool/sumfact/sumfact.py
@@ -3,6 +3,7 @@ from dune.perftool.pdelab.argument import (name_coefficientcontainer,
                                            )
 from dune.perftool.generation import (backend,
                                       domain,
+                                      function_mangler,
                                       get_counter,
                                       get_global_context_value,
                                       globalarg,
@@ -30,6 +31,7 @@ from pymbolic.primitives import (Call,
                                  )
 from dune.perftool.sumfact.quadrature import quadrature_inames
 from loopy import Reduction
+from loopy.symbolic import FunctionIdentifier
 
 from pytools import product
 
@@ -121,22 +123,24 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
     result = sum_factorization_kernel(a_matrices, "reffub", 2)
 
     from dune.perftool.pdelab.spaces import LFSLocalIndex
+
     # Now write all this into the correct residual
+    lfs = name_lfs(accterm.argument.argexpr.ufl_element(),
+                   accterm.argument.restriction,
+                   accterm.argument.component,
+                   )
     inames = tuple(sumfact_iname(mat.rows, 'accum') for mat in a_matrices)
-    globalarg("r", shape=(basis_functions_per_direction() ^ dim,))
-#     instruction(expression=Subscript(Variable(result), tuple(Variable(i) for i in inames)),
-# #                 assignee=Subscript(Variable("r"), (Call(LFSLocalIndex("lfs"), tuple(Variable(i) for i in inames)),)),
-#                 assignee=Subscript(Variable("r.d"), (0,)),
-#                 forced_iname_deps=frozenset(inames),
-#                 forced_iname_deps_is_final=True,
-#                 depends_on=frozenset({stage_insn(3)}),
-#                 )
-
-#     # Do stage 3 (for f=u => mass matrix)
-#     theta_transposed = name_theta_transposed()
-#     a_matrix_transposed = AMatrix(theta_transposed, cols, rows)
-#     a_matrices_transposed = (a_matrix_transposed, a_matrix_transposed)
-#     var = sum_factorization_kernel(a_matrices_transposed, "buffer", 2)
+    globalarg("r",
+              shape=(basis_functions_per_direction(),) * dim,
+              transform=LFSLocalIndex(lfs),
+              )
+
+    instruction(expression=Subscript(Variable(result), tuple(Variable(i) for i in inames)),
+                assignee=Subscript(Variable('r'), tuple(Variable(i) for i in inames)),
+                forced_iname_deps=frozenset(inames),
+                forced_iname_deps_is_final=True,
+                depends_on=frozenset({stage_insn(3)}),
+                )
 
 
 def sum_factorization_kernel(a_matrices, buffer, stage, insn_dep=frozenset({})):
-- 
GitLab