From d22f81e7db61d56516197e4efcd37093db383251 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Tue, 18 Oct 2016 13:16:08 +0200
Subject: [PATCH] Fix jacobians of systems and use pytools.Record for modified
 arguments

---
 python/dune/perftool/pdelab/argument.py       |  1 -
 python/dune/perftool/pdelab/localoperator.py  | 24 +++--
 python/dune/perftool/pdelab/spaces.py         |  4 +-
 .../dune/perftool/ufl/modified_terminals.py   | 91 ++++++++++---------
 .../extract_accumulation_terms.py             | 42 ++++++---
 5 files changed, 93 insertions(+), 69 deletions(-)

diff --git a/python/dune/perftool/pdelab/argument.py b/python/dune/perftool/pdelab/argument.py
index d62f884d..a55a485c 100644
--- a/python/dune/perftool/pdelab/argument.py
+++ b/python/dune/perftool/pdelab/argument.py
@@ -14,7 +14,6 @@ from dune.perftool.generation import (domain,
                                       valuearg,
                                       get_global_context_value
                                       )
-from dune.perftool.ufl.modified_terminals import ModifiedArgumentDescriptor
 from dune.perftool.pdelab import (name_index,
                                   restricted_name,
                                   )
diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py
index 71f552cc..d7a2e545 100644
--- a/python/dune/perftool/pdelab/localoperator.py
+++ b/python/dune/perftool/pdelab/localoperator.py
@@ -279,6 +279,7 @@ def determine_accumulation_space(expr, number, measure):
     if len(args) == 0:
         return AccumulationSpace()
 
+    # There should be but one modified argument, as the splitting eliminated all others.
     assert(len(args) == 1)
     ma, = args
 
@@ -295,7 +296,7 @@ def determine_accumulation_space(expr, number, measure):
 
     if len(subel.value_shape()) != 0:
         from dune.perftool.pdelab.geometry import dimension_iname
-        idims = tuple(dimension_iname(context='arg', count=number) for i in range(len(subel.value_shape())))
+        idims = tuple(dimension_iname(context='arg', count=i) for i in range(len(subel.value_shape())))
         lfs = lfs_child(lfs, idims, shape=subel.value_shape(), symmetry=subel.symmetry)
         subel = subel.sub_elements()[0]
 
@@ -370,7 +371,8 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
         from ufl.domain import find_geometric_dimension
         dim = find_geometric_dimension(accterm.argument.expr)
         for i in accterm.argument.index._indices:
-            additional_inames = additional_inames.union(frozenset({grad_iname(i, dim)}))
+            if i not in visitor.dimension_indices:
+                additional_inames = additional_inames.union(frozenset({grad_iname(i, dim)}))
 
     # It may happen that an entire accumulation term vanishes. We do nothing in that case
     if pymbolic_expr == 0:
@@ -441,20 +443,22 @@ def generate_kernel(integrals):
             from dune.perftool.pdelab.spaces import traverse_lfs_tree
             traverse_lfs_tree(ma)
 
-        from dune.perftool.options import set_option
-        set_option('print_transformations', True)
-        set_option('print_transformations_dir', '.')
-
         # Now split the given integrand into accumulation expressions
         from dune.perftool.ufl.transformations.extract_accumulation_terms import split_into_accumulation_terms
         accterms = split_into_accumulation_terms(integrand)
 
-        # Get a transformer instance for this kernel
-        from dune.perftool.ufl.visitor import UFL2LoopyVisitor
-        visitor = UFL2LoopyVisitor(measure, dimension_indices)
-
         # Iterate over the terms and generate a kernel
         for term in accterms:
+            # Adjust the index map for the visitor
+            from copy import deepcopy
+            indexmap = deepcopy(dimension_indices)
+            for i, j in term.indexmap.items():
+                if i in indexmap:
+                    indexmap[j] = indexmap[i]
+
+            # Get a transformer instance for this kernel
+            from dune.perftool.ufl.visitor import UFL2LoopyVisitor
+            visitor = UFL2LoopyVisitor(measure, indexmap)
             generate_accumulation_instruction(visitor, term, measure, subdomain_id)
 
     # Extract the information, which is needed to create a loopy kernel.
diff --git a/python/dune/perftool/pdelab/spaces.py b/python/dune/perftool/pdelab/spaces.py
index 3812cc1f..e439425f 100644
--- a/python/dune/perftool/pdelab/spaces.py
+++ b/python/dune/perftool/pdelab/spaces.py
@@ -161,8 +161,8 @@ def type_gfs(element, basetype=None, index_stack=None):
 
 
 def traverse_lfs_tree(arg):
-    from dune.perftool.ufl.modified_terminals import ModifiedArgumentDescriptor
-    assert isinstance(arg, ModifiedArgumentDescriptor)
+    from dune.perftool.ufl.modified_terminals import ModifiedArgument
+    assert isinstance(arg, ModifiedArgument)
 
     # First we need to determine the basename as given in the signature of
     # this kernel method!
diff --git a/python/dune/perftool/ufl/modified_terminals.py b/python/dune/perftool/ufl/modified_terminals.py
index 1bed3056..33319b27 100644
--- a/python/dune/perftool/ufl/modified_terminals.py
+++ b/python/dune/perftool/ufl/modified_terminals.py
@@ -3,6 +3,30 @@
 from ufl.algorithms import MultiFunction
 from dune.perftool import Restriction
 from ufl.classes import MultiIndex
+from pytools import Record
+
+
+class ModifiedArgument(Record):
+    def __init__(self,
+                 expr=None,
+                 argexpr=None,
+                 grad=False,
+                 index=None,
+                 reference_grad=False,
+                 restriction=Restriction.NONE,
+                 component=MultiIndex(()),
+                 reference=False,
+                 ):
+        Record.__init__(self,
+                        expr=expr,
+                        argexpr=argexpr,
+                        grad=grad,
+                        index=index,
+                        reference_grad=reference_grad,
+                        restriction=restriction,
+                        component=component,
+                        reference=reference,
+                        )
 
 
 class ModifiedTerminalTracker(MultiFunction):
@@ -10,6 +34,9 @@ class ModifiedTerminalTracker(MultiFunction):
     grad, reference_grad, positive_restricted and negative_restricted.
     The appearance of those classes changes the internal state of the MF.
     """
+
+    call = MultiFunction.__call__
+
     def __init__(self):
         MultiFunction.__init__(self)
         self.grad = False
@@ -59,57 +86,33 @@ class ModifiedTerminalTracker(MultiFunction):
         return ret
 
 
-class ModifiedArgumentDescriptor(MultiFunction):
-    def __init__(self, e):
-        MultiFunction.__init__(self)
-
-        self.grad = False
-        self.reference = False
-        self.reference_grad = False
+class ModifiedArgumentAnalysis(ModifiedTerminalTracker):
+    def __init__(self):
         self.index = None
-        self.restriction = Restriction.NONE
-        self.component = MultiIndex(())
-        self.expr = e
-
-        self.__call__(e)
-        self.__call__ = None
+        ModifiedTerminalTracker.__init__(self)
 
-    def __eq__(self, other):
-        return self.expr == other.expr
-
-    def grad(self, o):
-        self.grad = True
-        self(o.ufl_operands[0])
-
-    def reference_grad(self, o):
-        self.reference_grad = True
-        self(o.ufl_operands[0])
-
-    def reference_value(self, o):
-        self.reference = True
-        self(o.ufl_operands[0])
-
-    def positive_restricted(self, o):
-        self.restriction = Restriction.POSITIVE
-        self(o.ufl_operands[0])
-
-    def negative_restricted(self, o):
-        self.restriction = Restriction.NEGATIVE
-        self(o.ufl_operands[0])
+    def __call__(self, o):
+        self.call_expr = o
+        return self.call(o)
 
     def indexed(self, o):
         self.index = o.ufl_operands[1]
-        self(o.ufl_operands[0])
+        return self.call(o.ufl_operands[0])
 
-    def function_view(self, o):
-        self.component = o.ufl_operands[1]
-        self(o.ufl_operands[0])
+    def form_argument(self, o):
+        return ModifiedArgument(expr=self.call_expr,
+                                argexpr=o,
+                                index=self.index,
+                                restriction=self.restriction,
+                                component=self.component,
+                                grad=self.grad,
+                                reference_grad=self.reference_grad,
+                                reference=self.reference,
+                                )
 
-    def argument(self, o):
-        self.argexpr = o
 
-    def coefficient(self, o):
-        self.argexpr = o
+def analyse_modified_argument(expr):
+    return ModifiedArgumentAnalysis()(expr)
 
 
 class _ModifiedArgumentExtractor(MultiFunction):
@@ -123,7 +126,7 @@ class _ModifiedArgumentExtractor(MultiFunction):
         if ret:
             # This indicates that this entire expression was a modified thing...
             self.modified_arguments.add(ret)
-        return tuple(ModifiedArgumentDescriptor(ma) for ma in self.modified_arguments)
+        return tuple(analyse_modified_argument(ma) for ma in self.modified_arguments)
 
     def expr(self, o):
         for op in o.ufl_operands:
diff --git a/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py b/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py
index 3d53f9c2..0429c1c9 100644
--- a/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py
+++ b/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py
@@ -7,7 +7,7 @@ from dune.perftool.ufl.transformations import ufl_transformation
 from dune.perftool.ufl.transformations.replace import replace_expression
 from dune.perftool.ufl.transformations.identitypropagation import identity_propagation
 from dune.perftool.ufl.transformations.reindexing import reindexing
-from dune.perftool.ufl.modified_terminals import ModifiedArgumentDescriptor
+from dune.perftool.ufl.modified_terminals import analyse_modified_argument, ModifiedArgument
 
 from ufl.classes import Zero, Identity, Indexed, IntValue, MultiIndex
 from ufl.core.multiindex import indices
@@ -16,12 +16,21 @@ from pytools import Record
 
 
 class AccumulationTerm(Record):
-    def __init__(self, term, argument):
-        Record.__init__(self, term=term, argument=argument)
+    def __init__(self,
+                 term,
+                 argument,
+                 indexmap={},
+                 ):
+        assert isinstance(argument, ModifiedArgument)
+        Record.__init__(self,
+                        term=term,
+                        argument=argument,
+                        indexmap=indexmap,
+                        )
 
 
 @ufl_transformation(name="accterms", extraction_lambda=lambda l: [at.term for at in l])
-def split_into_accumulation_terms(expr):
+def split_into_accumulation_terms(expr, indexmap={}):
     ret = []
 
     # Extract a list of modified terminals for the test function
@@ -29,8 +38,7 @@ def split_into_accumulation_terms(expr):
     test_args = extract_modified_arguments(expr, argnumber=0)
 
     # Extract a list of modified terminals for the ansatz function
-    # in jacobian forms. Only the restriction of those terminals will
-    # be used to generate new accumulation terms!
+    # in jacobian forms.
     all_jacobian_args = extract_modified_arguments(expr, argnumber=1)
 
     for test_arg in test_args:
@@ -39,18 +47,23 @@ def split_into_accumulation_terms(expr):
 
         # 1) We first cut the expression to the relevant modified test_function
         # Build a replacement dictionary
-        replacement = {ma.expr: Zero() for ma in test_args}
+        replacement = {ma.expr: Zero(shape=ma.expr.ufl_shape,
+                                     free_indices=ma.expr.ufl_free_indices,
+                                     index_dimensions=ma.expr.ufl_index_dimensions)
+                       for ma in test_args}
         replacement[test_arg.expr] = test_arg.expr
         replace_expr = replace_expression(expr, replacemap=replacement)
 
         # 2) Cut the test function itself from the expression
+        indexmap = {}
         if test_arg.index:
             newi = indices(len(test_arg.index))
             identities = tuple(Indexed(Identity(2), MultiIndex((i,) + (j,))) for i, j in zip(newi, test_arg.index._indices))
+            indexmap = {i: j for i, j in zip(test_arg.index._indices, newi)}
             from dune.perftool.ufl.flatoperators import construct_binary_operator
             from ufl.classes import Product
             replacement = {test_arg.expr: construct_binary_operator(identities, Product)}
-            test_arg = ModifiedArgumentDescriptor(reindexing(test_arg.expr, replacemap={i: j for i, j in zip(test_arg.index._indices, newi)}))
+            test_arg = analyse_modified_argument(reindexing(test_arg.expr, replacemap=indexmap))
         else:
             replacement = {test_arg.expr: IntValue(1)}
         replace_expr = replace_expression(replace_expr, replacemap=replacement)
@@ -60,16 +73,21 @@ def split_into_accumulation_terms(expr):
 
         # 4) Further split according to trial function in jacobian terms
         if all_jacobian_args:
-            for jac_arg in all_jacobian_args:
+            # Update the list!
+            jac_args = extract_modified_arguments(replace_expr, argnumber=1)
+            for jac_arg in jac_args:
                 # TODO Some jacobian terms can be joined
-                replacement = {ma.expr: Zero() for ma in all_jacobian_args}
+                replacement = {ma.expr: Zero(shape=ma.expr.ufl_shape,
+                                             free_indices=ma.expr.ufl_free_indices,
+                                             index_dimensions=ma.expr.ufl_index_dimensions)
+                               for ma in jac_args}
                 replacement[jac_arg.expr] = jac_arg.expr
                 jac_expr = replace_expression(replace_expr, replacemap=replacement)
 
                 if not isinstance(jac_expr, Zero):
-                    ret.append(AccumulationTerm(jac_expr, test_arg))
+                    ret.append(AccumulationTerm(jac_expr, test_arg, indexmap))
         else:
             if not isinstance(replace_expr, Zero):
-                ret.append(AccumulationTerm(replace_expr, test_arg))
+                ret.append(AccumulationTerm(replace_expr, test_arg, indexmap))
 
     return ret
-- 
GitLab