From 11ecb83965516fb7fe7425e4436ca9c9da3b120a Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Mon, 7 Nov 2016 12:59:00 +0100
Subject: [PATCH] Implement correct accumulation (also for jacobians!)

---
 python/dune/perftool/generation/loopy.py   |  4 +-
 python/dune/perftool/loopy/flatten.py      | 52 ++++++++++++++++++++++
 python/dune/perftool/loopy/target.py       |  6 ---
 python/dune/perftool/sumfact/quadrature.py | 11 ++++-
 python/dune/perftool/sumfact/sumfact.py    | 45 ++++++++++++-------
 5 files changed, 92 insertions(+), 26 deletions(-)
 create mode 100644 python/dune/perftool/loopy/flatten.py

diff --git a/python/dune/perftool/generation/loopy.py b/python/dune/perftool/generation/loopy.py
index 4f237d61..451149cf 100644
--- a/python/dune/perftool/generation/loopy.py
+++ b/python/dune/perftool/generation/loopy.py
@@ -18,11 +18,11 @@ silenced_warning = generator_factory(item_tags=("silenced_warning",), no_deco=Tr
 
 @generator_factory(item_tags=("argument", "globalarg"),
                    cache_key_generator=lambda n, **kw: n)
-def globalarg(name, shape=loopy.auto, argtype=loopy.GlobalArg, **kw):
+def globalarg(name, shape=loopy.auto, **kw):
     if isinstance(shape, str):
         shape = (shape,)
     dtype = kw.pop("dtype", numpy.float64)
-    return argtype(name, dtype=dtype, shape=shape, **kw)
+    return loopy.GlobalArg(name, dtype=dtype, shape=shape, **kw)
 
 
 @generator_factory(item_tags=("argument", "constantarg"),
diff --git a/python/dune/perftool/loopy/flatten.py b/python/dune/perftool/loopy/flatten.py
new file mode 100644
index 00000000..5e5415d8
--- /dev/null
+++ b/python/dune/perftool/loopy/flatten.py
@@ -0,0 +1,52 @@
+from loopy.kernel.array import (convert_computed_to_fixed_dim_tags,
+                                get_access_info,
+                                parse_array_dim_tags,
+                                )
+
+
+class _DummyArrayObject(object):  # array stand-in handed to get_access_info
+    def __init__(self, dim_tags):
+        self.name = 'isthiseverused'  # placeholder, presumably never shown
+        self.offset = None
+        self.dim_tags = dim_tags  # stored unchanged
+
+    def num_target_axes(self):
+        return 1  # always a single target axis
+
+    def vector_size(self, target):
+        # The target is ignored for now; this should query it instead
+        return 1
+
+
+def flatten_index(index, shape, order="c"):
+    """
+    Flatten a multiindex into one index expression.
+
+    index is a tuple of indices, shape the shape of the multi
+    dimensional array (of the same length as index) and order
+    the axis order ("c" for row major, "f" for column major).
+    Loopy of course does this automatically in a lot of places.
+    This code is only meant to be used if a flat index needs
+    to be manually created.
+    """
+    assert order in ("c", "f")
+    assert len(index) == len(shape)  # one index per array axis
+
+    # Build one dim tag per index, all of the requested order (e.g. "c,c")
+    dim_tags = parse_array_dim_tags(",".join(order for i in index))
+
+    # Transform them to fixed stride tags
+    dim_tags = convert_computed_to_fixed_dim_tags("blubber",  # Name unused
+                                                  len(index),  # number of user axes
+                                                  1,  # number of implementation axes
+                                                  shape,
+                                                  dim_tags,
+                                                  )
+    accinfo = get_access_info(None,  # the target fed into above _DummyArrayObject.vector_size
+                              _DummyArrayObject(dim_tags),  # the array duck
+                              index,
+                              lambda x: x,  # eval_expr: identity (semantics unclear - TODO confirm)
+                              None,  # vectorization info
+                              )
+
+    return accinfo.subscripts[0]  # presumably the only one (1 implementation axis)
diff --git a/python/dune/perftool/loopy/target.py b/python/dune/perftool/loopy/target.py
index 2785737f..35097ac7 100644
--- a/python/dune/perftool/loopy/target.py
+++ b/python/dune/perftool/loopy/target.py
@@ -1,5 +1,4 @@
 from dune.perftool.loopy.temporary import DuneTemporaryVariable
-from dune.perftool.sumfact.sumfact import AccumulationArg
 from dune.perftool.pdelab.spaces import LFSLocalIndex
 
 from loopy.target import (TargetBase,
@@ -33,11 +32,6 @@ class MyMapper(ExpressionToCExpressionMapper):
             for i in expr.index:
                 ret = Subscript(ret, i)
             return ret
-        elif isinstance(arr, AccumulationArg):
-            pseudo_subscript = Subscript(expr.aggregate, expr.index)
-            flattened = ExpressionToCExpressionMapper.map_subscript(self, pseudo_subscript, type_context)
-            transformed = ExpressionToCExpressionMapper.map_call(self, Call(LFSLocalIndex(arr.lfs), (flattened.index,)), type_context)
-            return Subscript(Variable(flattened.aggregate.name + '.base()'), transformed)
         else:
             return ExpressionToCExpressionMapper.map_subscript(self, expr, type_context)
 
diff --git a/python/dune/perftool/sumfact/quadrature.py b/python/dune/perftool/sumfact/quadrature.py
index ccb0c375..b59aa21f 100644
--- a/python/dune/perftool/sumfact/quadrature.py
+++ b/python/dune/perftool/sumfact/quadrature.py
@@ -59,6 +59,15 @@ def base_weight_function_mangler(target, func, dtypes):
         return CallMangleInfo(func.name, (NumpyType(numpy.float64),), ())
 
 
+def pymbolic_base_weight():
+    """ Return the base weight that the quadrature weight is multiplied
+    with. With the fast DG assembler this will handle the weighting of the
+    time discretization scheme. Currently this is the constant 1.0.
+    TODO: Introduce backend switch that uses above BaseWeight function
+    """
+    return 1.0
+
+
 @iname
 def sumfact_quad_iname(d, context):
     name = "quad_{}_{}".format(context, d)
@@ -92,7 +101,7 @@ def recursive_quadrature_weight(dir=0):
     formdata = get_global_context_value('formdata')
     dim = formdata.geometric_dimension
     if dir == dim:
-        return Call(BaseWeight(name_accumulation_variable()), ())
+        return pymbolic_base_weight()
     else:
         name = 'weight_{}'.format(dir)
         define_recursive_quadrature_weight(name, dir)
diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py
index 1e69bc32..fe489784 100644
--- a/python/dune/perftool/sumfact/sumfact.py
+++ b/python/dune/perftool/sumfact/sumfact.py
@@ -1,6 +1,7 @@
 from dune.perftool.pdelab.argument import (name_accumulation_variable,
                                            name_coefficientcontainer,
                                            pymbolic_coefficient,
+                                           PDELabAccumulationFunction,
                                            )
 from dune.perftool.generation import (backend,
                                       domain,
@@ -14,11 +15,13 @@ from dune.perftool.generation import (backend,
                                       temporary_variable,
                                       transform,
                                       )
+from dune.perftool.loopy.flatten import flatten_index
 from dune.perftool.loopy.buffer import (get_buffer_temporary,
                                         initialize_buffer,
                                         switch_base_storage,
                                         )
 from dune.perftool.sumfact.quadrature import nest_quadrature_loops
+from dune.perftool.pdelab.localoperator import determine_accumulation_space
 from dune.perftool.pdelab.spaces import name_lfs
 from dune.perftool.sumfact.amatrix import (AMatrix,
                                            quadrature_points_per_direction,
@@ -38,10 +41,6 @@ from loopy.symbolic import FunctionIdentifier
 from pytools import product
 
 
-class AccumulationArg(GlobalArg):
-    allowed_extra_kwargs = GlobalArg.allowed_extra_kwargs + ["lfs"]
-
-
 @iname
 def _sumfact_iname(bound, _type, count):
     name = "sf_{}_{}".format(_type, str(count))
@@ -111,21 +110,33 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
                                                 additional_inames=frozenset(visitor.inames),
                                                 )
 
-    # Now write all this into the correct residual
-    lfs = name_lfs(accterm.argument.argexpr.ufl_element(),
-                   accterm.argument.restriction,
-                   accterm.argument.component,
-                   )
     inames = tuple(sumfact_iname(mat.rows, 'accum') for mat in a_matrices)
+
+    # Collect the lfs and lfs indices for the accumulate call
+    test_lfs = determine_accumulation_space(accterm.argument.expr, 0, measure)
+    test_lfs.index = flatten_index(tuple(Variable(i) for i in inames),
+                                   (basis_functions_per_direction(),) * dim
+                                   )
+
+    # In the jacobian case, also determine the ansatz space
+    ansatz_lfs = determine_accumulation_space(accterm.term, 1, measure)
+    rank = 2 if visitor.inames else 1
+    if rank == 2:
+        ansatz_lfs.index = flatten_index(tuple(Variable(i) for i in visitor.inames),
+                                         (basis_functions_per_direction(),) * dim
+                                         )
+
+    # Construct the expression representing "{r,jac}.accumulate(..)"
     accum = name_accumulation_variable()
-    globalarg(accum,
-              shape=(basis_functions_per_direction(),) * dim,
-              argtype=AccumulationArg,
-              lfs=lfs,
-              )
-
-    instruction(expression=Subscript(Variable(result), tuple(Variable(i) for i in inames)),
-                assignee=Subscript(Variable(accum), tuple(Variable(i) for i in inames)),
+    expr = Call(PDELabAccumulationFunction(accum, rank),
+                (ansatz_lfs.get_args() +
+                 test_lfs.get_args() +
+                 (Subscript(Variable(result), tuple(Variable(i) for i in inames)),)
+                 )
+                )
+
+    instruction(assignees=(),
+                expression=expr,
                 forced_iname_deps=frozenset(inames + visitor.inames),
                 forced_iname_deps_is_final=True,
                 depends_on=insn_dep,
-- 
GitLab