From 16acae789589ba41e85b4be6fbbe9dff390ebbd0 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Mon, 20 Jun 2016 13:20:24 +0200
Subject: [PATCH] Refactor UFL2LoopyVisitor to have state

targetting better indexing logic.
---
 python/dune/perftool/loopy/transformer.py    | 37 ++++++++++----------
 python/dune/perftool/pdelab/localoperator.py |  7 ++--
 python/dune/perftool/pymbolic/placeholder.py | 12 +++++--
 3 files changed, 33 insertions(+), 23 deletions(-)

diff --git a/python/dune/perftool/loopy/transformer.py b/python/dune/perftool/loopy/transformer.py
index 8c12b261..8e863694 100644
--- a/python/dune/perftool/loopy/transformer.py
+++ b/python/dune/perftool/loopy/transformer.py
@@ -55,13 +55,16 @@ def get_outerloop():
 
 class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapper):
     def __init__(self, measure, subdomain_id):
+        # Some variables describing the integral measure of this integral
         self.measure = measure
         self.subdomain_id = subdomain_id
-        self.argshape = 0
-        self.redinames = ()
-        self.index_replacement_map = {}
+
+        # Call base class constructors
         super(UFL2LoopyVisitor, self).__init__()
 
+        # Some state variables that need to be persistent over multiple calls
+        self.index_placeholder_removal_mapper = IndexPlaceholderRemoval()
+
     def _assign(self, o):
         # In some corner cases we do not even need a temporary variable
         if isinstance(o, int) or isinstance(o, float):
@@ -90,13 +93,13 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp
         # Change the assignee!
         if not merge_into_main_loopnest:
             assignee_index_placeholder = LFSIndexPlaceholderExtraction()(o).pop()
-            assignee_index = IndexPlaceholderRemoval(duplicate_inames=True)(assignee_index_placeholder)
+            assignee_index = self.index_placeholder_removal_mapper(assignee_index_placeholder, duplicate_inames=True)
             assignee = Subscript(assignee, (assignee_index,))
             temp_shape = (name_lfs_bound(name_leaf_lfs(assignee_index_placeholder.element, assignee_index_placeholder.restriction)),)
 
         # Now introduce duplicate inames for the argument loop if necessary
-        replaced_iname_deps = [IndexPlaceholderRemoval(duplicate_inames=not merge_into_main_loopnest, wrap_in_variable=False, index_replacement_map=self.index_replacement_map)(i) for i in iname_deps]
-        replaced_expr = IndexPlaceholderRemoval(duplicate_inames=not merge_into_main_loopnest, index_replacement_map=self.index_replacement_map)(o)
+        replaced_iname_deps = [self.index_placeholder_removal_mapper(i, duplicate_inames=not merge_into_main_loopnest, wrap_in_variable=False) for i in iname_deps]
+        replaced_expr = self.index_placeholder_removal_mapper(o, duplicate_inames=not merge_into_main_loopnest)
 
         # Now we assign this expression to a new temporary variable
         insn_id = instruction(assignee=assignee,
@@ -111,7 +114,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp
 
         retvar = Variable(temp)
         if not merge_into_main_loopnest:
-            retvar_index = IndexPlaceholderRemoval(index_replacement_map=self.index_replacement_map)(assignee_index_placeholder)
+            retvar_index = self.index_placeholder_removal_mapper(assignee_index_placeholder)
             retvar = Subscript(retvar, (retvar_index,))
 
         # Now that we know its exact name, declare the temporary
@@ -120,6 +123,11 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp
         return retvar
 
     def __call__(self, o):
+        # Reset some state variables that are reinitialized for each accumulation term
+        self.argshape = 0
+        self.redinames = ()
+        self.inames = []
+
         # Initialize the local function spaces that we might need for this term
         # We therefore gather a list of modified trial functions too.
         from dune.perftool.ufl.modified_terminals import extract_modified_arguments
@@ -137,9 +145,6 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp
         # Determine the rank of the term
         self.rank = len(test_ma)
 
-        # And initialize a list of found inames
-        self.inames = []
-
         # First we do the tree traversal to get a pymbolic expression representing this expression
         pymbolic_expr = self.call(o)
 
@@ -152,7 +157,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp
             pymbolic_expr = self._assign(pymbolic_expr)
 
         # Transform the IndexPlaceholders into real inames
-        self.inames = [IndexPlaceholderRemoval(wrap_in_variable=False)(i) for i in self.inames]
+        self.inames = [self.index_placeholder_removal_mapper(i, wrap_in_variable=False) for i in self.inames]
 
         # Collect the arguments for the accumulate function
         accumargs = [None] * (2 * len(test_ma))
@@ -329,7 +334,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp
 
         for i in range(self.argshape):
             from dune.perftool.pdelab.geometry import dimension_iname
-            self.index_replacement_map[self.last_index[i].expr] = Variable(dimension_iname(context='arg'))
+            self.index_placeholder_removal_mapper.index_replacement_map[self.last_index[i].expr] = Variable(dimension_iname(context='arg'))
 
         self.argshape = 0
         if isinstance(aggr, Subscript):
@@ -368,11 +373,11 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp
             term = self.call(o.ufl_operands[0])
 
             from dune.perftool.pymbolic.inameset import get_index_inames
-            used_inames = frozenset({self.index_replacement_map.get(i, i).name for i in get_index_inames(term, as_variables=True)})
+            used_inames = frozenset({self.index_placeholder_removal_mapper.index_replacement_map.get(i, i).name for i in get_index_inames(term, as_variables=True)})
             self.inames = [i for i in used_inames.intersection(frozenset({i for i in oldinames}))] + self.inames
 
             # Now filter all those reduction inames that are marked for removal
-            implicit_inames = [i.name for i in self.index_replacement_map]
+            implicit_inames = [i.name for i in self.index_placeholder_removal_mapper.index_replacement_map]
             self.redinames = tuple(i for i in self.redinames if i not in implicit_inames)
 
             if len(self.redinames) > 0:
@@ -395,7 +400,3 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker, UFL2PymbolicMapper, GeometryMapp
         # One might as well take the uflname as string here, but I apply this function
         from dune.perftool.pdelab import name_index
         return IndexPlaceholder(Variable(name_index(o)))
-
-
-def transform_accumulation_term(term, measure, subdomain_id):
-    return UFL2LoopyVisitor(measure, subdomain_id)(term)
diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py
index a7c28b06..596b1414 100644
--- a/python/dune/perftool/pdelab/localoperator.py
+++ b/python/dune/perftool/pdelab/localoperator.py
@@ -165,10 +165,13 @@ def generate_kernel(integrals):
         from dune.perftool.ufl.transformations.extract_accumulation_terms import split_into_accumulation_terms
         accterms = split_into_accumulation_terms(integrand)
 
+        # Get a transformer instance for this kernel
+        from dune.perftool.loopy.transformer import UFL2LoopyVisitor
+        visitor = UFL2LoopyVisitor(measure, subdomain_id)
+
         # Iterate over the terms and generate a kernel
         for term in accterms:
-            from dune.perftool.loopy.transformer import transform_accumulation_term
-            transform_accumulation_term(term, measure, subdomain_id)
+            visitor(term)
 
     # Extract the information, which is needed to create a loopy kernel.
     # First extracting it, might be useful to alter it before kernel generation.
diff --git a/python/dune/perftool/pymbolic/placeholder.py b/python/dune/perftool/pymbolic/placeholder.py
index 7903c736..05e4787f 100644
--- a/python/dune/perftool/pymbolic/placeholder.py
+++ b/python/dune/perftool/pymbolic/placeholder.py
@@ -32,11 +32,17 @@ class LFSIndexPlaceholder(IndexPlaceholderBase):
 
 
 class IndexPlaceholderRemoval(IdentityMapper):
-    def __init__(self, wrap_in_variable=True, duplicate_inames=False, index_replacement_map={}):
+    def __init__(self):
+        # Initialize base class
+        super(IndexPlaceholderRemoval, self).__init__()
+
+        # Initialize state variables that are persistent over multiple calls
+        self.index_replacement_map = {}
+
+    def __call__(self, o, wrap_in_variable=True, duplicate_inames=False):
         self.duplicate_inames = duplicate_inames
         self.wrap_in_variable = wrap_in_variable
-        self.index_replacement_map = index_replacement_map
-        super(IndexPlaceholderRemoval, self).__init__()
+        return self.rec(o)
 
     def map_foreign(self, o):
         # How do we map constants here? map_constant was not correct
-- 
GitLab