From eb28ceb1fefb7412c7b32d11eca32799c23ff96a Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Thu, 21 Jul 2016 17:44:10 +0200
Subject: [PATCH] New take on dimension indices, WIP

---
 python/dune/perftool/loopy/transformer.py    | 21 +++----------
 python/dune/perftool/pdelab/localoperator.py |  5 ++-
 python/dune/perftool/ufl/dimensionindex.py   | 32 ++++++++++++++++++++
 3 files changed, 41 insertions(+), 17 deletions(-)
 create mode 100644 python/dune/perftool/ufl/dimensionindex.py

diff --git a/python/dune/perftool/loopy/transformer.py b/python/dune/perftool/loopy/transformer.py
index 90049a40..ba72d430 100644
--- a/python/dune/perftool/loopy/transformer.py
+++ b/python/dune/perftool/loopy/transformer.py
@@ -35,17 +35,12 @@ from dune.perftool.pdelab.quadrature import quadrature_iname
 from pymbolic.primitives import Subscript, Variable
 
 
-@iname
-def index_sum_iname(i):
-    from dune.perftool.pdelab import name_index
-    return name_index(i)
-
-
 class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapper):
-    def __init__(self, measure, subdomain_id):
+    def __init__(self, measure, subdomain_id, dimension_index_aliases):
         # Some variables describing the integral measure of this integral
         self.measure = measure
         self.subdomain_id = subdomain_id
+        self.dimension_index_aliases = dimension_index_aliases
 
         # Call base class constructors
         super(UFL2LoopyVisitor, self).__init__()
@@ -55,7 +50,6 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp
         self.argshape = 0
         self.redinames = ()
         self.inames = []
-        self.dimension_index_aliases = []
         self.substitution_rules = []
 
         # Initialize the local function spaces that we might need for this term
@@ -265,10 +259,6 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp
 
         use_indices = self.last_index[self.argshape:]
 
-        for i in range(self.argshape):
-            self.dimension_index_aliases.append(i)
-#             self.index_placeholder_removal_mapper.index_replacement_map[self.last_index[i].expr] = Variable(dimension_iname(context='arg'))
-
         self.argshape = 0
         if isinstance(aggr, Subscript):
             return Subscript(aggr.aggregate, aggr.index + use_indices)
@@ -287,7 +277,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp
 
         # Get the iname for the reduction index
         ind = o.ufl_operands[1][0]
-        self.redinames = self.redinames + (index_sum_iname(ind),)
+        self.redinames = self.redinames + (ind,)
         shape = o.ufl_operands[0].ufl_index_dimensions[0]
         from dune.perftool.pdelab import name_index
         domain(name_index(ind), shape)
@@ -301,11 +291,9 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp
 
             # Recurse to get the summation expression
             term = self.call(o.ufl_operands[0])
-
             self.redinames = tuple(i for i in self.redinames if i not in self.dimension_index_aliases)
-
             if len(self.redinames) > 0:
-                ret = Reduction("sum", self.redinames, term)
+                ret = Reduction("sum", tuple(name_index(ind) for ind in self.redinames), term)
                 name = get_temporary_name()
                 # Generate a substitution rule for this one.
                 from loopy import SubstitutionRule
@@ -333,6 +321,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp
             from dune.perftool.pdelab import name_index
             if index in self.dimension_index_aliases:
                 from dune.perftool.pdelab.geometry import dimension_iname
+                self.inames.append(dimension_iname(context='arg'))
                 return Variable(dimension_iname(context='arg'))
             else:
                 return Variable(name_index(index))
diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py
index 310dc616..28c46c33 100644
--- a/python/dune/perftool/pdelab/localoperator.py
+++ b/python/dune/perftool/pdelab/localoperator.py
@@ -161,13 +161,16 @@ def generate_kernel(integrals):
         subdomain_id = integral.subdomain_id()
         subdomain_data = integral.subdomain_data()
 
+        from dune.perftool.ufl.dimensionindex import collect_dimension_index_aliases
+        dimension_index_aliases = collect_dimension_index_aliases(integrand)
+
         # Now split the given integrand into accumulation expressions
         from dune.perftool.ufl.transformations.extract_accumulation_terms import split_into_accumulation_terms
         accterms = split_into_accumulation_terms(integrand)
 
         # Get a transformer instance for this kernel
         from dune.perftool.loopy.transformer import UFL2LoopyVisitor
-        visitor = UFL2LoopyVisitor(measure, subdomain_id)
+        visitor = UFL2LoopyVisitor(measure, subdomain_id, dimension_index_aliases)
 
         # Iterate over the terms and generate a kernel
         for term in accterms:
diff --git a/python/dune/perftool/ufl/dimensionindex.py b/python/dune/perftool/ufl/dimensionindex.py
new file mode 100644
index 00000000..1446bd4d
--- /dev/null
+++ b/python/dune/perftool/ufl/dimensionindex.py
@@ -0,0 +1,32 @@
+""" Extract all the aliases of dimension indices """
+
+from ufl.algorithms import MultiFunction
+
+
+class _CollectDimensionIndexAliases(MultiFunction):
+    call = MultiFunction.__call__
+
+    def __call__(self, o):
+        self.shape = 0
+        return self.call(o)
+
+    def expr(self, o):
+        return frozenset({}).union(*tuple(self.call(op) for op in o.ufl_operands))
+
+    def terminal(self, o):
+        return frozenset({})
+
+    def function_view(self, o):
+        self.shape = len(o.ufl_operands[1])
+        return frozenset({})
+
+    def indexed(self, o):
+        ret = self.call(o.ufl_operands[0])
+        if self.shape:
+            ret = ret.union(frozenset({o.ufl_operands[1][:self.shape][0]}))
+        self.shape = 0
+        return ret
+
+
+def collect_dimension_index_aliases(expr):
+    return _CollectDimensionIndexAliases()(expr)
-- 
GitLab