diff --git a/python/dune/perftool/ufl/visitor.py b/python/dune/perftool/ufl/visitor.py index e4983773ded3655989b72b74913b05282392780f..2027c303bf7af76cf2a94a97b7bea738ece6b80a 100644 --- a/python/dune/perftool/ufl/visitor.py +++ b/python/dune/perftool/ufl/visitor.py @@ -276,17 +276,23 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): self.indices = None return 0 + def _evaluate_function(self, python_func, c_func, val): + if isinstance(val, (float, int)): + return python_func(val) + else: + return prim.Call(prim.Variable(c_func), (val,)) + def abs(self, o): if isinstance(o.ufl_operands[0], JacobianDeterminant): return self.call(o.ufl_operands[0]) else: - return Call(Variable('abs'), (self.call(o.ufl_operands[0]),)) + return self._evaluate_function(abs, "abs", self.call(o.ufl_operands[0])) def exp(self, o): - return Call(Variable('exp'), (self.call(o.ufl_operands[0]),)) + return self._evaluate_function(exp, "exp", self.call(o.ufl_operands[0])) def sqrt(self, o): - return Call(Variable('sqrt'), (self.call(o.ufl_operands[0]),)) + return self._evaluate_function(sqrt, "sqrt", self.call(o.ufl_operands[0])) def power(self, o): from ufl.constantvalue import IntValue @@ -333,7 +339,14 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): self.indices = indices op2 = self.call(o.ufl_operands[2]) - return prim.If(condition, op1, op2) + try: + evaluated = eval(str(condition)) + if evaluated: + return op1 + else: + return op2 + except: + return prim.If(condition, op1, op2) def eq(self, o): return prim.Comparison(self.call(o.ufl_operands[0]),