From 93af3b0b5cb2c5105603560ad98d68eb0f7fe2f2 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Mon, 24 Apr 2017 17:20:05 +0200
Subject: [PATCH] Cleanup accumulation.py

---
 python/dune/perftool/sumfact/accumulation.py | 86 ++++++--------------
 1 file changed, 23 insertions(+), 63 deletions(-)

diff --git a/python/dune/perftool/sumfact/accumulation.py b/python/dune/perftool/sumfact/accumulation.py
index b508071d..d65a20b8 100644
--- a/python/dune/perftool/sumfact/accumulation.py
+++ b/python/dune/perftool/sumfact/accumulation.py
@@ -67,7 +67,7 @@ class AlreadyAssembledInput(SumfactKernelInputBase, ImmutableRecord):
 
 @backend(interface="accum_insn", name="sumfact")
 def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
-    # When doing sumfactorization we want to split the test function
+    # When doing sum factorization we want to split the test function
     assert(accterm.argument.expr is not None)
 
     # Do the tree traversal to get a pymbolic expression representing this expression
@@ -75,16 +75,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
     if pymbolic_expr == 0:
         return
 
-    # If this is a gradient, we find the gradient iname
     dim = world_dimension()
-    additional_inames = frozenset({})
-    if accterm.new_indices is not None:
-        for i in accterm.new_indices:
-            if i not in visitor.dimension_indices:
-                from dune.perftool.pdelab.localoperator import grad_iname
-                additional_inames = additional_inames.union(frozenset({grad_iname(i, dim)}))
-
-    # Get the degree of the element corresponding to this modified argument
     mod_arg_expr = accterm.argument.expr
     from ufl.classes import FunctionView, Argument
     while (not isinstance(mod_arg_expr, FunctionView)) and (not isinstance(mod_arg_expr, Argument)):
@@ -92,52 +83,40 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
     degree = mod_arg_expr.ufl_element()._degree
     basis_size = degree + 1
 
-    def emit_sumfact_kernel(indices, restriction, insn_dep):
-        # Not implemented
-        if indices:
-            assert len(indices) <= 2
-
-        # Figure out the name of the accumulation variable
-        accum_idims = None
-        if indices is not None:
-            accum_idims = (indices[0],)
-        test_lfs = determine_accumulation_space(accterm.argument.expr, 0, measure, idims=accum_idims)
-        ansatz_lfs = determine_accumulation_space(accterm.term, 1, measure, idims=accum_idims)
-        accum = name_accumulation_variable(test_lfs.get_restriction() + ansatz_lfs.get_restriction())
+    # Extract index information
+    grad_index = None
+    if accterm.argument.reference_grad:
+        grad_index = accterm.argument.expr.ufl_operands[1][0]._value
 
-        # TODO: Adjust for stokes sumfact symdiff
-        jacobian_inames = tuple()
-        if accterm.is_jacobian:
-            jacobian_inames = visitor.inames
+    accum_index = None
+    if visitor.indices:
+        accum_index = visitor.indices[0]
 
-        # Determine the derivative direction
-        derivative = None
-        if accterm.new_indices:
-            derivative = indices[-1]
-        if isinstance(accterm.argument.expr, uc.Indexed):
-            derivative = accterm.argument.expr.ufl_operands[1][0]._value
+    jacobian_inames = tuple()
+    if accterm.is_jacobian:
+        jacobian_inames = visitor.inames
+
+    def emit_sumfact_kernel(restriction, insn_dep):
+        test_lfs = determine_accumulation_space(accterm.argument.expr, 0, measure, idims=(accum_index,))
+        ansatz_lfs = determine_accumulation_space(accterm.term, 1, measure, idims=(accum_index,))
+        accum = name_accumulation_variable(test_lfs.get_restriction() + ansatz_lfs.get_restriction())
 
         # Construct the matrix sequence for this sum factorization
         matrix_sequence = construct_basis_matrix_sequence(
             transpose=True,
-            derivative=derivative,
+            derivative=grad_index,
             facedir=get_facedir(accterm.argument.restriction),
             facemod=get_facemod(accterm.argument.restriction),
             basis_size=basis_size)
 
-        # Avoid caching issues by passing the coeff_func_index
-        coeff_func_index = None
-        if indices and len(indices) == 2:
-            coeff_func_index = indices[0]
-
         # TODO: Adapt preferred position for stokes sumfact symdiff
         sf = SumfactKernel(matrix_sequence=matrix_sequence,
                            restriction=(accterm.argument.restriction, restriction),
                            stage=3,
-                           preferred_position=indices[-1] if accterm.new_indices else None,
+                           preferred_position=grad_index,
                            accumvar=accum,
                            within_inames=jacobian_inames,
-                           input=AlreadyAssembledInput(index=coeff_func_index),
+                           input=AlreadyAssembledInput(index=accum_index),
                            )
 
         from dune.perftool.sumfact.vectorization import attach_vectorization_info
@@ -167,19 +146,6 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
                         tags=frozenset(["quadvec", "gradvec"])
                         )
 
-        # Replace gradient iname with correct index for assignement
-        replace_dict = {}
-
-        # If we have two indices the first belongs to the dimension
-        # and the second one to the derivative. We handle all these
-        # indices by hand and replace them accordingly.
-        if indices and len(indices) == 2:
-                replace_dict['idim_arg0'] = indices[0]
-        for iname in additional_inames:
-            replace_dict[prim.Variable(iname)] = indices[-1]
-        from dune.perftool.loopy.symbolic import substitute
-        expression = substitute(pymbolic_expr, replace_dict)
-
         # Write timing stuff for jacobian (for alpha methods it is done at the end of stage 1)
         timer_dep = frozenset()
         if get_option("instrumentation_level") >= 4:
@@ -194,14 +160,14 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
         from loopy.match import Or, Writes
         from loopy.symbolic import DependencyMapper
         from dune.perftool.tools import get_pymbolic_basename
-        deps = Or(tuple(Writes(get_pymbolic_basename(expr)) for expr in DependencyMapper()(expression)))
+        deps = Or(tuple(Writes(get_pymbolic_basename(expr)) for expr in DependencyMapper()(pymbolic_expr)))
 
         # Issue an instruction in the quadrature loop that fills the buffer
         # with the evaluation of the contribution at all quadrature points
         assignee = prim.Subscript(lp.TaggedVariable(temp, vsf.tag),
                                   vsf.quadrature_index(sf))
         contrib_dep = instruction(assignee=assignee,
-                                  expression=expression,
+                                  expression=pymbolic_expr,
                                   forced_iname_deps=frozenset(quadrature_inames() + jacobian_inames),
                                   forced_iname_deps_is_final=True,
                                   tags=frozenset({"quadvec"}).union(vectag),
@@ -243,7 +209,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
         # Determine the expression to accumulate with. This depends on the vectorization strategy!
         result = prim.Subscript(result, tuple(prim.Variable(i) for i in inames))
         vecinames = ()
-        # TODO: evaluate whether the following line would be okay with vsf.vectorized
+
         if vsf.vectorized:
             iname = accum_iname((accterm.argument.restriction, restriction), vsf.vector_width, "vec")
             vecinames = (iname,)
@@ -285,10 +251,4 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
 
     insn_dep = None
     for restriction in jac_restrictions:
-        if accterm.new_indices:
-            shape = (world_dimension(),) * len(accterm.new_indices)
-            # Iterate over all combinations of indices for this shape
-            for indices in itertools.product(*map(range, shape)):
-                insn_dep = emit_sumfact_kernel(indices, restriction, insn_dep)
-        else:
-            insn_dep = emit_sumfact_kernel(None, restriction, insn_dep)
+        insn_dep = emit_sumfact_kernel(restriction, insn_dep)
-- 
GitLab