From 405b7dfed060831ed911c79188f2fb924d7fb280 Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Wed, 6 Sep 2017 16:56:25 +0200 Subject: [PATCH] Fix sympy simplification and do it by default --- python/dune/perftool/sympy.py | 36 +++++++++++++++++++++-------- python/dune/perftool/ufl/visitor.py | 2 ++ 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/python/dune/perftool/sympy.py b/python/dune/perftool/sympy.py index 419fccd3..f4adfa44 100644 --- a/python/dune/perftool/sympy.py +++ b/python/dune/perftool/sympy.py @@ -1,29 +1,45 @@ from __future__ import absolute_import +from dune.perftool.generation import get_global_context_value from pymbolic.interop.sympy import SympyToPymbolicMapper, PymbolicToSympyMapper import pymbolic.primitives as prim +import loopy as lp import sympy as sp - + class MyPymbolicToSympyMapper(PymbolicToSympyMapper): def map_subscript(self, expr): - indices = expr.index if isinstance(expr.index, tuple) else (expr.index,) - return sp.Symbol("{}${}".format(expr.aggregate, "$".join(str(i) for i in indices))) + return sp.tensor.indexed.Indexed( + self.rec(expr.aggregate), + *tuple(self.rec(i) for i in expr.index_tuple) + ) + + def map_tagged_variable(self, expr): + return sp.Symbol("{}${}".format(expr.name, expr.tag)) class MySympyToPymbolicMapper(SympyToPymbolicMapper): + def map_Indexed(self, expr): + return prim.Subscript( + self.rec(expr.args[0].args[0]), + tuple(self.rec(i) for i in expr.args[1:]) + ) + def map_Symbol(self, expr): s = expr.name.split('$') r = prim.Variable(s[0]) if len(s) == 1: - return r - if len(s) == 2: - return prim.Subscript(r, prim.Variable(s[1])) + return prim.Variable(s[0]) else: - return prim.Subscript(r, tuple(prim.Variable(i) for i in s[1:])) + return lp.symbolic.TaggedVariable(s[0], s[1]) def simplify_pymbolic_expression(e): - sympyexpr = MyPymbolicToSympyMapper()(e) - simplified = sp.simplify(sympyexpr) - return MySympyToPymbolicMapper()(simplified) + if get_global_context_value("dry_run"): + # If we are on the dry run, we skip this because our expression + # may involve nodes that have no sympy equivalent (SumfactKernel...) + return e + else: + sympyexpr = MyPymbolicToSympyMapper()(e) + simplified = sp.simplify(sympyexpr) + return MySympyToPymbolicMapper()(simplified) diff --git a/python/dune/perftool/ufl/visitor.py b/python/dune/perftool/ufl/visitor.py index 7e6187e0..ff02efb8 100644 --- a/python/dune/perftool/ufl/visitor.py +++ b/python/dune/perftool/ufl/visitor.py @@ -55,6 +55,8 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): self.current_info = info expr = self._call(o, False) if expr != 0: + from dune.perftool.sympy import simplify_pymbolic_expression + expr = simplify_pymbolic_expression(expr) self.interface.generate_accumulation_instruction(expr, self) def _call(self, o, do_predicates): -- GitLab