From 405b7dfed060831ed911c79188f2fb924d7fb280 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Wed, 6 Sep 2017 16:56:25 +0200
Subject: [PATCH] Fix sympy simplification and do it by default

---
 python/dune/perftool/sympy.py       | 36 +++++++++++++++++++++--------
 python/dune/perftool/ufl/visitor.py |  2 ++
 2 files changed, 28 insertions(+), 10 deletions(-)

diff --git a/python/dune/perftool/sympy.py b/python/dune/perftool/sympy.py
index 419fccd3..f4adfa44 100644
--- a/python/dune/perftool/sympy.py
+++ b/python/dune/perftool/sympy.py
@@ -1,29 +1,45 @@
 from __future__ import absolute_import
+from dune.perftool.generation import get_global_context_value
 from pymbolic.interop.sympy import SympyToPymbolicMapper, PymbolicToSympyMapper
 
 import pymbolic.primitives as prim
+import loopy as lp
 import sympy as sp
 
-
+ 
 class MyPymbolicToSympyMapper(PymbolicToSympyMapper):
     def map_subscript(self, expr):
-        indices = expr.index if isinstance(expr.index, tuple) else (expr.index,)
-        return sp.Symbol("{}${}".format(expr.aggregate, "$".join(str(i) for i in indices)))
+        return sp.tensor.indexed.Indexed(
+            self.rec(expr.aggregate),
+            *tuple(self.rec(i) for i in expr.index_tuple)
+            )
+
+    def map_tagged_variable(self, expr):
+        return sp.Symbol("{}${}".format(expr.name, expr.tag))
 
 
 class MySympyToPymbolicMapper(SympyToPymbolicMapper):
+    def map_Indexed(self, expr):
+        return prim.Subscript(
+            self.rec(expr.args[0].args[0]),
+            tuple(self.rec(i) for i in expr.args[1:])
+            )
+
     def map_Symbol(self, expr):
         s = expr.name.split('$')
         r = prim.Variable(s[0])
         if len(s) == 1:
-            return r
-        if len(s) == 2:
-            return prim.Subscript(r, prim.Variable(s[1]))
+            return prim.Variable(s[0])
         else:
-            return prim.Subscript(r, tuple(prim.Variable(i) for i in s[1:]))
+            return lp.symbolic.TaggedVariable(s[0], s[1])
 
 
 def simplify_pymbolic_expression(e):
-    sympyexpr = MyPymbolicToSympyMapper()(e)
-    simplified = sp.simplify(sympyexpr)
-    return MySympyToPymbolicMapper()(simplified)
+    if get_global_context_value("dry_run"):
+        # If we are on the dry run, we skip this because our expression
+        # may involve nodes that have no sympy equivalent (SumfactKernel...)
+        return e
+    else:
+        sympyexpr = MyPymbolicToSympyMapper()(e)
+        simplified = sp.simplify(sympyexpr)
+        return MySympyToPymbolicMapper()(simplified)
diff --git a/python/dune/perftool/ufl/visitor.py b/python/dune/perftool/ufl/visitor.py
index 7e6187e0..ff02efb8 100644
--- a/python/dune/perftool/ufl/visitor.py
+++ b/python/dune/perftool/ufl/visitor.py
@@ -55,6 +55,8 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
             self.current_info = info
             expr = self._call(o, False)
             if expr != 0:
+                from dune.perftool.sympy import simplify_pymbolic_expression
+                expr = simplify_pymbolic_expression(expr)
                 self.interface.generate_accumulation_instruction(expr, self)
 
     def _call(self, o, do_predicates):
-- 
GitLab