diff --git a/python/dune/perftool/ufl/transformations/indexpushdown.py b/python/dune/perftool/ufl/transformations/indexpushdown.py index 1dd7139d880d947cb8da33a26630cc008ce514f8..73b8d73a670f66517060ac61338a683da275d90e 100644 --- a/python/dune/perftool/ufl/transformations/indexpushdown.py +++ b/python/dune/perftool/ufl/transformations/indexpushdown.py @@ -16,9 +16,9 @@ class IndexPushDown(MultiFunction): terms = [uc.Indexed(self(term), idx) for term in get_operands(expr)] return construct_binary_operator(terms, uc.Sum) elif isinstance(expr, uc.Conditional): - return uc.Conditional(expr.ufl_operands[0], - uc.Indexed(self(expr.ufl_operands[1]), idx), - uc.Indexed(self(expr.ufl_operands[2]), idx) + return uc.Conditional(self(expr.ufl_operands[0]), + self(uc.Indexed(expr.ufl_operands[1], idx)), + self(uc.Indexed(expr.ufl_operands[2], idx)) ) else: # This is a normal indexed, we treat it as any other. diff --git a/python/dune/perftool/ufl/visitor.py b/python/dune/perftool/ufl/visitor.py index 071102aea95bb5bb83ed44304a17e57e5ffcec73..aa9d7714a2ae60da6fd37911eced5f15af259d46 100644 --- a/python/dune/perftool/ufl/visitor.py +++ b/python/dune/perftool/ufl/visitor.py @@ -258,23 +258,6 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): # Those handlers would be valid in any code going from UFL to pymbolic # - def _call_with_preserved_index(self, ops): - # This can be used if the indexing state needs to be reused for several - # branches. This happens so far in the following scenarios: - # * Sums of tensors - # * Conditionals - ind = self.indices - after_ind = [] - ret = [] - for o in ops: - self.indices = ind - ret.append(self.call(o)) - after_ind.append(self.indices) - - assert(len(set(after_ind)) == 1) - self.indices = after_ind[0] - return tuple(ret) - def product(self, o): return prim.flattened_product(tuple(self.call(op) for op in o.ufl_operands)) @@ -288,7 +271,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): return prim.quotient(self.call(o.ufl_operands[0]), self.call(o.ufl_operands[1])) def sum(self, o): - return prim.flattened_sum(self._call_with_preserved_index(o.ufl_operands)) + return prim.flattened_sum(tuple(self.call(op) for op in o.ufl_operands)) def zero(self, o): # UFL has Zeroes with shape. We ignore those indices. @@ -360,17 +343,15 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): # def conditional(self, o): - ind = self.indices - self.indices = None cond = self.call(o.ufl_operands[0]) - self.indices = ind # Try to evaluate the condition at code generation time try: evaluated = eval(str(cond)) except: - branches = self._call_with_preserved_index(o.ufl_operands[1:]) - return prim.If(cond, *branches) + return prim.If(cond, + self.call(o.ufl_operands[1]), + self.call(o.ufl_operands[2])) # User code generation time evaluation if evaluated: