Skip to content
Snippets Groups Projects
Commit 405b7dfe authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Fix sympy simplification and do it by default

parent bec9c8f4
No related branches found
No related tags found
No related merge requests found
from __future__ import absolute_import
from dune.perftool.generation import get_global_context_value
from pymbolic.interop.sympy import SympyToPymbolicMapper, PymbolicToSympyMapper
import pymbolic.primitives as prim
import loopy as lp
import sympy as sp
class MyPymbolicToSympyMapper(PymbolicToSympyMapper):
def map_subscript(self, expr):
indices = expr.index if isinstance(expr.index, tuple) else (expr.index,)
return sp.Symbol("{}${}".format(expr.aggregate, "$".join(str(i) for i in indices)))
return sp.tensor.indexed.Indexed(
self.rec(expr.aggregate),
*tuple(self.rec(i) for i in expr.index_tuple)
)
def map_tagged_variable(self, expr):
return sp.Symbol("{}${}".format(expr.name, expr.tag))
class MySympyToPymbolicMapper(SympyToPymbolicMapper):
def map_Indexed(self, expr):
return prim.Subscript(
self.rec(expr.args[0].args[0]),
tuple(self.rec(i) for i in expr.args[1:])
)
def map_Symbol(self, expr):
s = expr.name.split('$')
r = prim.Variable(s[0])
if len(s) == 1:
return r
if len(s) == 2:
return prim.Subscript(r, prim.Variable(s[1]))
return prim.Variable(s[0])
else:
return prim.Subscript(r, tuple(prim.Variable(i) for i in s[1:]))
return lp.symbolic.TaggedVariable(s[0], s[1])
def simplify_pymbolic_expression(e):
sympyexpr = MyPymbolicToSympyMapper()(e)
simplified = sp.simplify(sympyexpr)
return MySympyToPymbolicMapper()(simplified)
if get_global_context_value("dry_run"):
# If we are on the dry run, we skip this because our expression
# may involve nodes that have no sympy equivalent (SumfactKernel...)
return e
else:
sympyexpr = MyPymbolicToSympyMapper()(e)
simplified = sp.simplify(sympyexpr)
return MySympyToPymbolicMapper()(simplified)
......@@ -55,6 +55,8 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
self.current_info = info
expr = self._call(o, False)
if expr != 0:
from dune.perftool.sympy import simplify_pymbolic_expression
expr = simplify_pymbolic_expression(expr)
self.interface.generate_accumulation_instruction(expr, self)
def _call(self, o, do_predicates):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment