Skip to content
Snippets Groups Projects
Commit d020431b authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Evaluate some functions at code generation time

parent a535a3e0
No related branches found
No related tags found
No related merge requests found
......@@ -276,17 +276,23 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
self.indices = None
return 0
def _evaluate_function(self, python_func, c_func, val):
if isinstance(val, (float, int)):
return python_func(val)
else:
return prim.Call(prim.Variable(c_func), (val,))
def abs(self, o):
if isinstance(o.ufl_operands[0], JacobianDeterminant):
return self.call(o.ufl_operands[0])
else:
return Call(Variable('abs'), (self.call(o.ufl_operands[0]),))
return self._evaluate_function(abs, "abs", self.call(o.ufl_operands[0]))
def exp(self, o):
return Call(Variable('exp'), (self.call(o.ufl_operands[0]),))
return self._evaluate_function(exp, "exp", self.call(o.ufl_operands[0]))
def sqrt(self, o):
return Call(Variable('sqrt'), (self.call(o.ufl_operands[0]),))
return self._evaluate_function(sqrt, "sqrt", self.call(o.ufl_operands[0]))
def power(self, o):
from ufl.constantvalue import IntValue
......@@ -333,7 +339,14 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
self.indices = indices
op2 = self.call(o.ufl_operands[2])
return prim.If(condition, op1, op2)
try:
evaluated = eval(str(condition))
if evaluated:
return op1
else:
return op2
except:
return prim.If(condition, op1, op2)
def eq(self, o):
return prim.Comparison(self.call(o.ufl_operands[0]),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment