From d020431bcd40a2d4eadb7a8221368d0bec72baf1 Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Fri, 25 Aug 2017 10:58:31 +0200 Subject: [PATCH] Evaluate some functions at code generation time --- python/dune/perftool/ufl/visitor.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/python/dune/perftool/ufl/visitor.py b/python/dune/perftool/ufl/visitor.py index e4983773..2027c303 100644 --- a/python/dune/perftool/ufl/visitor.py +++ b/python/dune/perftool/ufl/visitor.py @@ -276,17 +276,23 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): self.indices = None return 0 + def _evaluate_function(self, python_func, c_func, val): + if isinstance(val, (float, int)): + return python_func(val) + else: + return prim.Call(prim.Variable(c_func), (val,)) + def abs(self, o): if isinstance(o.ufl_operands[0], JacobianDeterminant): return self.call(o.ufl_operands[0]) else: - return Call(Variable('abs'), (self.call(o.ufl_operands[0]),)) + return self._evaluate_function(abs, "abs", self.call(o.ufl_operands[0])) def exp(self, o): - return Call(Variable('exp'), (self.call(o.ufl_operands[0]),)) + return self._evaluate_function(exp, "exp", self.call(o.ufl_operands[0])) def sqrt(self, o): - return Call(Variable('sqrt'), (self.call(o.ufl_operands[0]),)) + return self._evaluate_function(sqrt, "sqrt", self.call(o.ufl_operands[0])) def power(self, o): from ufl.constantvalue import IntValue @@ -333,7 +339,14 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): self.indices = indices op2 = self.call(o.ufl_operands[2]) - return prim.If(condition, op1, op2) + try: + evaluated = eval(str(condition)) + if evaluated: + return op1 + else: + return op2 + except: + return prim.If(condition, op1, op2) def eq(self, o): return prim.Comparison(self.call(o.ufl_operands[0]), -- GitLab