From d020431bcd40a2d4eadb7a8221368d0bec72baf1 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Fri, 25 Aug 2017 10:58:31 +0200
Subject: [PATCH] Evaluate some functions at code generation time

---
 python/dune/perftool/ufl/visitor.py | 21 +++++++++++++++++----
 1 file changed, 17 insertions(+), 4 deletions(-)

diff --git a/python/dune/perftool/ufl/visitor.py b/python/dune/perftool/ufl/visitor.py
index e4983773..2027c303 100644
--- a/python/dune/perftool/ufl/visitor.py
+++ b/python/dune/perftool/ufl/visitor.py
@@ -276,17 +276,23 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
         self.indices = None
         return 0
 
+    def _evaluate_function(self, python_func, c_func, val):
+        if isinstance(val, (float, int)):
+            return python_func(val)
+        else:
+            return prim.Call(prim.Variable(c_func), (val,))
+
     def abs(self, o):
         if isinstance(o.ufl_operands[0], JacobianDeterminant):
             return self.call(o.ufl_operands[0])
         else:
-            return Call(Variable('abs'), (self.call(o.ufl_operands[0]),))
+            return self._evaluate_function(abs, "abs", self.call(o.ufl_operands[0]))
 
     def exp(self, o):
-        return Call(Variable('exp'), (self.call(o.ufl_operands[0]),))
+        return self._evaluate_function(exp, "exp", self.call(o.ufl_operands[0]))
 
     def sqrt(self, o):
-        return Call(Variable('sqrt'), (self.call(o.ufl_operands[0]),))
+        return self._evaluate_function(sqrt, "sqrt", self.call(o.ufl_operands[0]))
 
     def power(self, o):
         from ufl.constantvalue import IntValue
@@ -333,7 +339,14 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
         self.indices = indices
         op2 = self.call(o.ufl_operands[2])
 
-        return prim.If(condition, op1, op2)
+        try:
+            evaluated = eval(str(condition))
+            if evaluated:
+                return op1
+            else:
+                return op2
+        except:
+            return prim.If(condition, op1, op2)
 
     def eq(self, o):
         return prim.Comparison(self.call(o.ufl_operands[0]),
-- 
GitLab