diff --git a/python/dune/perftool/ufl/transformations/indexpushdown.py b/python/dune/perftool/ufl/transformations/indexpushdown.py
index eccc0203c5ced563fa2ddd710811c211e06af63b..1dd7139d880d947cb8da33a26630cc008ce514f8 100644
--- a/python/dune/perftool/ufl/transformations/indexpushdown.py
+++ b/python/dune/perftool/ufl/transformations/indexpushdown.py
@@ -15,6 +15,11 @@ class IndexPushDown(MultiFunction):
         if isinstance(expr, uc.Sum):
             terms = [uc.Indexed(self(term), idx) for term in get_operands(expr)]
             return construct_binary_operator(terms, uc.Sum)
+        elif isinstance(expr, uc.Conditional):
+            return uc.Conditional(expr.ufl_operands[0],
+                                  uc.Indexed(self(expr.ufl_operands[1]), idx),
+                                  uc.Indexed(self(expr.ufl_operands[2]), idx)
+                                  )
         else:
             # This is a normal indexed, we treat it as any other.
             return self.expr(o)
@@ -23,9 +28,11 @@ class IndexPushDown(MultiFunction):
 @ufl_transformation(name="index_pushdown")
 def pushdown_indexed(e):
     """
-    Removes the following antipattern from UFL expressions:
-    (a+b)[i] -> a[i] + b[i]
-    If similar antipatterns arise with a node other than sum,
+    Removes the following antipatterns from UFL expressions:
+    * (a+b)[i] -> a[i] + b[i]
+    * (a ? b : c)[i] -> a ? b[i] : c[i]
+
+    If similar antipatterns arise with further nodes,
     add the corresponding handlers here.
     """
     return IndexPushDown()(e)
diff --git a/python/dune/perftool/ufl/visitor.py b/python/dune/perftool/ufl/visitor.py
index 9128a295ec8c9d92a46167399029ef6dd713a37b..ae1ccdae1a3f686c3fb2e43e8b82252980657d49 100644
--- a/python/dune/perftool/ufl/visitor.py
+++ b/python/dune/perftool/ufl/visitor.py
@@ -344,22 +344,21 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
     #
 
     def conditional(self, o):
-        condition = self.call(o.ufl_operands[0])
-        indices = self.indices
-
-        op1 = self.call(o.ufl_operands[1])
-        # Restore indexing information for the second branch
-        self.indices = indices
-        op2 = self.call(o.ufl_operands[2])
+        cond = self.call(o.ufl_operands[0])
 
+        # Try to evaluate the condition at code generation time
         try:
-            evaluated = eval(str(condition))
-            if evaluated:
-                return op1
-            else:
-                return op2
+            evaluated = eval(str(cond))
         except:
-            return prim.If(condition, op1, op2)
+            return prim.If(self.call(o.ufl_operands[0]),
+                           self.call(o.ufl_operands[1]),
+                           self.call(o.ufl_operands[2]))
+
+        # User code generation time evaluation
+        if evaluated:
+            return self.call(o.ufl_operands[1])
+        else:
+            return self.call(o.ufl_operands[2])
 
     def eq(self, o):
         return prim.Comparison(self.call(o.ufl_operands[0]),