Skip to content
Snippets Groups Projects
Commit c3519e39 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Fix more sympy glitches: Floor Division and mod!

parent 6a678106
No related branches found
No related tags found
No related merge requests found
......@@ -43,11 +43,25 @@ class MyPymbolicToSympyMapper(PymbolicToSympyMapper):
class MySympyToPymbolicMapper(SympyToPymbolicMapper):
def map_floor(self, expr):
# Try finding patterns arising from FloorDiv
assert isinstance(expr.args[0], sp.Mul)
margs = expr.args[0].args
if isinstance(margs[1], sp.Pow) and int(margs[1].args[1]) == -1:
return prim.FloorDiv(self.rec(margs[0]), self.rec(margs[1].args[0]))
elif isinstance(margs[0], sp.Rational) and margs[0].as_numer_denom()[0] == 1:
return prim.FloorDiv(self.rec(margs[1]), self.rec(margs[0].as_numer_denom()[1]))
else:
raise NotImplementedError("Congratulations, sympy.floor showed you its deficits!")
def map_Indexed(self, expr):
return prim.Subscript(self.rec(expr.args[0].args[0]),
tuple(self.rec(i) for i in expr.args[1:])
)
def map_Mod(self, expr):
return prim.Remainder(self.rec(expr.args[0]), self.rec(expr.args[1]))
def map_Symbol(self, expr):
s = expr.name.split('$')
r = prim.Variable(s[0])
......@@ -85,6 +99,9 @@ def simplify_pymbolic_expression(e):
# may involve nodes that have no sympy equivalent (SumfactKernel...)
return e
else:
sympyexpr = MyPymbolicToSympyMapper()(e)
forward = MyPymbolicToSympyMapper()
backward = MySympyToPymbolicMapper()
sympyexpr = forward(e)
simplified = sp.simplify(sympyexpr)
return MySympyToPymbolicMapper()(simplified)
return backward(simplified)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment