From db2b3524be04c767d32bd69c099167d422d6cc2d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ren=C3=A9=20He=C3=9F?= <rene.hess@iwr.uni-heidelberg.de>
Date: Wed, 29 Mar 2017 15:41:43 +0200
Subject: [PATCH] Simplify dimension index handling

---
 python/dune/perftool/pdelab/localoperator.py      | 15 +++++----------
 .../perftool/ufl/extract_accumulation_terms.py    | 13 +++++++------
 2 files changed, 12 insertions(+), 16 deletions(-)

diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py
index 24f7e937..33ffcc9c 100644
--- a/python/dune/perftool/pdelab/localoperator.py
+++ b/python/dune/perftool/pdelab/localoperator.py
@@ -380,10 +380,6 @@ def visit_integrals(integrals):
                 from dune.perftool.ufl.transformations.axiparallel import diagonal_jacobian
                 integrand = diagonal_jacobian(integrand)
 
-        # Gather dimension indices
-        from dune.perftool.ufl.dimensionindex import dimension_index_mapping
-        dimension_indices = dimension_index_mapping(integrand)
-
         # Generate code for the LFS trees present in the form
         from dune.perftool.ufl.modified_terminals import extract_modified_arguments
         test_ma = extract_modified_arguments(integrand, argnumber=0)
@@ -404,12 +400,11 @@ def visit_integrals(integrals):
 
         # Iterate over the terms and generate a kernel
         for accterm in accterms:
-            # Adjust the index map for the visitor
-            from copy import deepcopy
-            indexmap = deepcopy(dimension_indices)
-            for i, j in accterm.indexmap.items():
-                if i in indexmap:
-                    indexmap[j] = indexmap[i]
+            # Get dimension indices
+            from dune.perftool.ufl.dimensionindex import dimension_index_mapping
+            indexmap = dimension_index_mapping(accterm.test_arg())
+            # For jacobian there can also be dimension indices in the expression
+            indexmap.update(dimension_index_mapping(accterm.term))
 
             # Get a transformer instance for this kernel
             if get_option('sumfact'):
diff --git a/python/dune/perftool/ufl/extract_accumulation_terms.py b/python/dune/perftool/ufl/extract_accumulation_terms.py
index 29ee9e78..11dbdcee 100644
--- a/python/dune/perftool/ufl/extract_accumulation_terms.py
+++ b/python/dune/perftool/ufl/extract_accumulation_terms.py
@@ -22,17 +22,21 @@ class AccumulationTerm(Record):
     def __init__(self,
                  term,
                  argument,
-                 indexmap={},
                  new_indices=None
                  ):
         assert isinstance(argument, ModifiedArgument)
         Record.__init__(self,
                         term=term,
                         argument=argument,
-                        indexmap=indexmap,
                         new_indices=new_indices
                         )
 
+    def test_arg(self):
+        if self.new_indices is None:
+            return self.argument.expr
+        else:
+            return Indexed(self.argument.expr, MultiIndex(self.new_indices))
+
 
 @ufl_transformation(name="accterms", extraction_lambda=lambda l: [at.term for at in l])
 def split_into_accumulation_terms(expr):
@@ -153,7 +157,6 @@ def cut_accumulation_term(expr):
     #
     # In step 2 this will collaps to: a_{m,n} + b_{m,n}
     replacement = {}
-    indexmap = {}
     newi = None
     backmap = {}
     # Get all appearances of test functions with their indices
@@ -194,7 +197,6 @@ def cut_accumulation_term(expr):
                 backmap.update({mod_indexed_test_arg:
                                 Indexed(Identity(2), MultiIndex((newi[0], newi[1])))})
                 replacement.update({indexed_test_arg.expr: rep})
-                indexmap.update({indexed_test_arg.index[0]: newi[0]})
             else:
                 # Replace indexed test function with a product of identities.
                 identities = tuple(Indexed(Identity(2), MultiIndex((i,) + (j,)))
@@ -202,7 +204,6 @@ def cut_accumulation_term(expr):
                 replacement.update({indexed_test_arg.expr:
                                     construct_binary_operator(identities, Product)})
 
-                indexmap.update({i: j for i, j in zip(indexed_test_arg.index._indices, newi)})
         else:
             replacement.update({indexed_test_arg.expr: IntValue(1)})
     expr = replace_expression(expr, replacemap=replacement)
@@ -212,4 +213,4 @@ def cut_accumulation_term(expr):
     expr = identity_propagation(expr)
     expr = replace_expression(expr, replacemap=backmap)
 
-    return AccumulationTerm(expr, test_arg, indexmap, newi)
+    return AccumulationTerm(expr, test_arg, newi)
-- 
GitLab