From ca0c381f2ac55a304a2b0ae61c55f91b4b4eb3d3 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Tue, 25 Jul 2017 11:31:20 +0200
Subject: [PATCH] Cleanup the indexpushdown transformation

---
 python/dune/perftool/ufl/preprocess.py            |  8 ++++----
 .../perftool/ufl/transformations/indexpushdown.py | 15 +++++++++++----
 2 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/python/dune/perftool/ufl/preprocess.py b/python/dune/perftool/ufl/preprocess.py
index df756bf7..0ee27f8a 100644
--- a/python/dune/perftool/ufl/preprocess.py
+++ b/python/dune/perftool/ufl/preprocess.py
@@ -25,13 +25,13 @@ def preprocess_form(form):
 
 
 def apply_default_transformations(form):
+    #
+    # This is the list of transformations we unconditionally apply to
+    # all forms we want to generate code for.
+    #
     from dune.perftool.ufl.transformations import transform_form
     from dune.perftool.ufl.transformations.indexpushdown import pushdown_indexed
-    from dune.perftool.ufl.transformations.reindexing import reindexing
-    from dune.perftool.ufl.transformations.unroll import unroll_dimension_loops
 
-#     form = transform_form(form, unroll_dimension_loops)
     form = transform_form(form, pushdown_indexed)
-#     form = transform_form(form, reindexing)
 
     return form
diff --git a/python/dune/perftool/ufl/transformations/indexpushdown.py b/python/dune/perftool/ufl/transformations/indexpushdown.py
index 1f1e0b64..f70a6b53 100644
--- a/python/dune/perftool/ufl/transformations/indexpushdown.py
+++ b/python/dune/perftool/ufl/transformations/indexpushdown.py
@@ -1,9 +1,10 @@
 from __future__ import absolute_import
 from ufl.algorithms import MultiFunction
-from ufl.classes import Sum, Indexed
 from dune.perftool.ufl.flatoperators import get_operands, construct_binary_operator
 from dune.perftool.ufl.transformations import ufl_transformation
 
+import ufl.classes as uc
+
 
 class IndexPushDown(MultiFunction):
     def expr(self, o):
@@ -11,9 +12,9 @@ class IndexPushDown(MultiFunction):
 
     def indexed(self, o):
         expr, idx = o.ufl_operands
-        if isinstance(expr, Sum):
-            terms = [Indexed(term, idx) for term in get_operands(expr)]
-            return construct_binary_operator(terms, Sum)
+        if isinstance(expr, uc.Sum):
+            terms = [uc.Indexed(term, idx) for term in get_operands(expr)]
+            return construct_binary_operator(terms, uc.Sum)
         else:
             # This is a normal indexed, we treat it as any other.
             return self.expr(o)
@@ -21,4 +22,10 @@ class IndexPushDown(MultiFunction):
 
 @ufl_transformation(name="index_pushdown")
 def pushdown_indexed(e):
+    """
+    Removes the following antipattern from UFL expressions:
+    (a+b)[i] -> a[i] + b[i]
+    If similar antipatterns arise with a node other than sum,
+    add the corresponding handlers here.
+    """
     return IndexPushDown()(e)
-- 
GitLab