diff --git a/python/dune/perftool/ufl/transformations/indexpushdown.py b/python/dune/perftool/ufl/transformations/indexpushdown.py index eccc0203c5ced563fa2ddd710811c211e06af63b..1dd7139d880d947cb8da33a26630cc008ce514f8 100644 --- a/python/dune/perftool/ufl/transformations/indexpushdown.py +++ b/python/dune/perftool/ufl/transformations/indexpushdown.py @@ -15,6 +15,11 @@ class IndexPushDown(MultiFunction): if isinstance(expr, uc.Sum): terms = [uc.Indexed(self(term), idx) for term in get_operands(expr)] return construct_binary_operator(terms, uc.Sum) + elif isinstance(expr, uc.Conditional): + return uc.Conditional(expr.ufl_operands[0], + uc.Indexed(self(expr.ufl_operands[1]), idx), + uc.Indexed(self(expr.ufl_operands[2]), idx) + ) else: # This is a normal indexed, we treat it as any other. return self.expr(o) @@ -23,9 +28,11 @@ class IndexPushDown(MultiFunction): @ufl_transformation(name="index_pushdown") def pushdown_indexed(e): """ - Removes the following antipattern from UFL expressions: - (a+b)[i] -> a[i] + b[i] - If similar antipatterns arise with a node other than sum, + Removes the following antipatterns from UFL expressions: + * (a+b)[i] -> a[i] + b[i] + * (a ? b : c)[i] -> a ? b[i] : c[i] + + If similar antipatterns arise with further nodes, add the corresponding handlers here. """ return IndexPushDown()(e) diff --git a/python/dune/perftool/ufl/visitor.py b/python/dune/perftool/ufl/visitor.py index 9128a295ec8c9d92a46167399029ef6dd713a37b..ae1ccdae1a3f686c3fb2e43e8b82252980657d49 100644 --- a/python/dune/perftool/ufl/visitor.py +++ b/python/dune/perftool/ufl/visitor.py @@ -344,22 +344,21 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): # def conditional(self, o): - condition = self.call(o.ufl_operands[0]) - indices = self.indices - - op1 = self.call(o.ufl_operands[1]) - # Restore indexing information for the second branch - self.indices = indices - op2 = self.call(o.ufl_operands[2]) + cond = self.call(o.ufl_operands[0]) + # Try to evaluate the condition at code generation time try: - evaluated = eval(str(condition)) - if evaluated: - return op1 - else: - return op2 + evaluated = eval(str(cond)) except: - return prim.If(condition, op1, op2) + return prim.If(self.call(o.ufl_operands[0]), + self.call(o.ufl_operands[1]), + self.call(o.ufl_operands[2])) + + # User code generation time evaluation + if evaluated: + return self.call(o.ufl_operands[1]) + else: + return self.call(o.ufl_operands[2]) def eq(self, o): return prim.Comparison(self.call(o.ufl_operands[0]),