From f57ec6b269498220d2e37ac26785760cf0ec867c Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Wed, 12 Oct 2016 15:38:36 +0200
Subject: [PATCH] Fix some stokes kernels

---
 python/dune/perftool/pdelab/localoperator.py         | 12 +++++++-----
 .../transformations/extract_accumulation_terms.py    | 11 +++++++----
 .../ufl/transformations/identitypropagation.py       | 11 ++++++-----
 3 files changed, 20 insertions(+), 14 deletions(-)

diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py
index 068cae0e..71f552cc 100644
--- a/python/dune/perftool/pdelab/localoperator.py
+++ b/python/dune/perftool/pdelab/localoperator.py
@@ -353,11 +353,10 @@ def boundary_predicates(expr, measure, subdomain_id):
 
 
 @iname
-def grad_iname(ma):
+def grad_iname(index, dim):
     from dune.perftool.pdelab import name_index
-    from ufl.domain import find_geometric_dimension
-    name = name_index(ma.index)
-    domain(name, find_geometric_dimension(ma.expr))
+    name = name_index(index)
+    domain(name, dim)
     return name
 
 
@@ -368,7 +367,10 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
     # If this is a gradient, we generate an iname
     additional_inames = frozenset({})
     if accterm.argument.index:
-        additional_inames = frozenset({grad_iname(accterm.argument)})
+        from ufl.domain import find_geometric_dimension
+        dim = find_geometric_dimension(accterm.argument.expr)
+        for i in accterm.argument.index._indices:
+            additional_inames = additional_inames.union(frozenset({grad_iname(i, dim)}))
 
     # It may happen that an entire accumulation term vanishes. We do nothing in that case
     if pymbolic_expr == 0:
diff --git a/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py b/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py
index 7fdf46b3..3d53f9c2 100644
--- a/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py
+++ b/python/dune/perftool/ufl/transformations/extract_accumulation_terms.py
@@ -44,10 +44,13 @@ def split_into_accumulation_terms(expr):
         replace_expr = replace_expression(expr, replacemap=replacement)
 
         # 2) Cut the test function itself from the expression
-        if test_arg.reference_grad:
-            newi = indices(1)
-            replacement = {test_arg.expr: Indexed(Identity(2), MultiIndex(newi + test_arg.index._indices))}
-            test_arg = ModifiedArgumentDescriptor(reindexing(test_arg.expr, replacemap={test_arg.index[0]: newi[0]}))
+        if test_arg.index:
+            newi = indices(len(test_arg.index))
+            identities = tuple(Indexed(Identity(2), MultiIndex((i,) + (j,))) for i, j in zip(newi, test_arg.index._indices))
+            from dune.perftool.ufl.flatoperators import construct_binary_operator
+            from ufl.classes import Product
+            replacement = {test_arg.expr: construct_binary_operator(identities, Product)}
+            test_arg = ModifiedArgumentDescriptor(reindexing(test_arg.expr, replacemap={i: j for i, j in zip(test_arg.index._indices, newi)}))
         else:
             replacement = {test_arg.expr: IntValue(1)}
         replace_expr = replace_expression(replace_expr, replacemap=replacement)
diff --git a/python/dune/perftool/ufl/transformations/identitypropagation.py b/python/dune/perftool/ufl/transformations/identitypropagation.py
index 2e7bbab5..aeb21a10 100644
--- a/python/dune/perftool/ufl/transformations/identitypropagation.py
+++ b/python/dune/perftool/ufl/transformations/identitypropagation.py
@@ -14,7 +14,7 @@ class GetIndexMap(MultiFunction):
     call = MultiFunction.__call__
 
     def __call__(self, o):
-        self.free_index = Index(o.ufl_free_indices[0])
+        self.free_indices = frozenset(Index(i) for i in o.ufl_free_indices)
         self.replacemap = {}
         self.call(o)
         return self.replacemap
@@ -26,10 +26,11 @@ class GetIndexMap(MultiFunction):
     def indexed(self, o):
         op, i = o.ufl_operands
         if isinstance(op, Identity):
-            assert(len(i) == 2)
-            assert(self.free_index in i)
-            ind, = set(i) - {self.free_index}
-            self.replacemap[ind] = self.free_index
+            free_index = self.free_indices.intersection(frozenset(i))
+            assert(len(free_index) == 1)
+            ind, = frozenset(i) - free_index
+            free_index, = free_index
+            self.replacemap[ind] = free_index
         else:
             self.call(op)
 
-- 
GitLab