diff --git a/python/dune/perftool/sympy.py b/python/dune/perftool/sympy.py index 853751b9cc7c82831025392d9087376ee2f60cb5..4999ad8aa90f6e69c4f122db9c4772b7c7cb85f0 100644 --- a/python/dune/perftool/sympy.py +++ b/python/dune/perftool/sympy.py @@ -43,11 +43,25 @@ class MyPymbolicToSympyMapper(PymbolicToSympyMapper): class MySympyToPymbolicMapper(SympyToPymbolicMapper): + def map_floor(self, expr): + # Try finding patterns arising from FloorDiv + assert isinstance(expr.args[0], sp.Mul) + margs = expr.args[0].args + if isinstance(margs[1], sp.Pow) and int(margs[1].args[1]) == -1: + return prim.FloorDiv(self.rec(margs[0]), self.rec(margs[1].args[0])) + elif isinstance(margs[0], sp.Rational) and margs[0].as_numer_denom()[0] == 1: + return prim.FloorDiv(self.rec(margs[1]), self.rec(margs[0].as_numer_denom()[1])) + else: + raise NotImplementedError("Congratulations, sympy.floor showed you its deficits!") + def map_Indexed(self, expr): return prim.Subscript(self.rec(expr.args[0].args[0]), tuple(self.rec(i) for i in expr.args[1:]) ) + def map_Mod(self, expr): + return prim.Remainder(self.rec(expr.args[0]), self.rec(expr.args[1])) + def map_Symbol(self, expr): s = expr.name.split('$') r = prim.Variable(s[0]) @@ -85,6 +99,9 @@ def simplify_pymbolic_expression(e): # may involve nodes that have no sympy equivalent (SumfactKernel...) return e else: - sympyexpr = MyPymbolicToSympyMapper()(e) + forward = MyPymbolicToSympyMapper() + backward = MySympyToPymbolicMapper() + + sympyexpr = forward(e) simplified = sp.simplify(sympyexpr) - return MySympyToPymbolicMapper()(simplified) + return backward(simplified)