diff --git a/python/dune/perftool/sympy.py b/python/dune/perftool/sympy.py index 419fccd3ed01849e29a6aea6cdaf2ef05a8280e8..f4adfa442c3663d42d6a33e4930755ca27e03ba3 100644 --- a/python/dune/perftool/sympy.py +++ b/python/dune/perftool/sympy.py @@ -1,29 +1,45 @@ from __future__ import absolute_import +from dune.perftool.generation import get_global_context_value from pymbolic.interop.sympy import SympyToPymbolicMapper, PymbolicToSympyMapper import pymbolic.primitives as prim +import loopy as lp import sympy as sp - + class MyPymbolicToSympyMapper(PymbolicToSympyMapper): def map_subscript(self, expr): - indices = expr.index if isinstance(expr.index, tuple) else (expr.index,) - return sp.Symbol("{}${}".format(expr.aggregate, "$".join(str(i) for i in indices))) + return sp.tensor.indexed.Indexed( + self.rec(expr.aggregate), + *tuple(self.rec(i) for i in expr.index_tuple) + ) + + def map_tagged_variable(self, expr): + return sp.Symbol("{}${}".format(expr.name, expr.tag)) class MySympyToPymbolicMapper(SympyToPymbolicMapper): + def map_Indexed(self, expr): + return prim.Subscript( + self.rec(expr.args[0].args[0]), + tuple(self.rec(i) for i in expr.args[1:]) + ) + def map_Symbol(self, expr): s = expr.name.split('$') r = prim.Variable(s[0]) if len(s) == 1: - return r - if len(s) == 2: - return prim.Subscript(r, prim.Variable(s[1])) + return prim.Variable(s[0]) else: - return prim.Subscript(r, tuple(prim.Variable(i) for i in s[1:])) + return lp.symbolic.TaggedVariable(s[0], s[1]) def simplify_pymbolic_expression(e): - sympyexpr = MyPymbolicToSympyMapper()(e) - simplified = sp.simplify(sympyexpr) - return MySympyToPymbolicMapper()(simplified) + if get_global_context_value("dry_run"): + # If we are on the dry run, we skip this because our expression + # may involve nodes that have no sympy equivalent (SumfactKernel...) + return e + else: + sympyexpr = MyPymbolicToSympyMapper()(e) + simplified = sp.simplify(sympyexpr) + return MySympyToPymbolicMapper()(simplified) diff --git a/python/dune/perftool/ufl/visitor.py b/python/dune/perftool/ufl/visitor.py index 7e6187e0bf30b79a6d770a80c44492ee2d6835be..ff02efb8ead28e0595ad3234b9433d4a21dd3e65 100644 --- a/python/dune/perftool/ufl/visitor.py +++ b/python/dune/perftool/ufl/visitor.py @@ -55,6 +55,8 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): self.current_info = info expr = self._call(o, False) if expr != 0: + from dune.perftool.sympy import simplify_pymbolic_expression + expr = simplify_pymbolic_expression(expr) self.interface.generate_accumulation_instruction(expr, self) def _call(self, o, do_predicates):