diff --git a/python/dune/perftool/ufl/visitor.py b/python/dune/perftool/ufl/visitor.py
index fcf2ee7728e96a57ebd94ada15966c992ff8120f..6c0ae101b1e85bef97b6ab8c4f315529f3a87550 100644
--- a/python/dune/perftool/ufl/visitor.py
+++ b/python/dune/perftool/ufl/visitor.py
@@ -248,6 +248,8 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
         if all(isinstance(i, int) for i in self.indices):
             index = self.indices[0]
             self.indices = self.indices[1:]
+            if len(self.indices) == 0:
+                self.indices = None
             return self.call(o.ufl_operands[index])
         else:
             return self.interface.pymbolic_list_tensor(o)
@@ -274,6 +276,26 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
     # Those handlers would be valid in any code going from UFL to pymbolic
     #
 
+    def _call_with_preserved_index(self, ops):
+        # Visit every operand starting from the same indexing state and
+        # return the results as a tuple. This is used so far for:
+        # * Sums of tensors
+        # * Conditionals
+        ind = self.indices
+        after_ind = []
+        ret = []
+        for o in ops:
+            self.indices = ind
+            ret.append(self.call(o))
+            after_ind.append(self.indices)
+
+        # Every operand must leave the indexing state in the same condition
+        if any(i != after_ind[0] for i in after_ind[1:]):
+            raise ValueError("Operands left inconsistent index states: "
+                             "{}".format(after_ind))
+        self.indices = after_ind[0]
+        return tuple(ret)
+
     def product(self, o):
         return prim.flattened_product(tuple(self.call(op) for op in o.ufl_operands))
 
@@ -287,7 +309,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
         return prim.quotient(self.call(o.ufl_operands[0]), self.call(o.ufl_operands[1]))
 
     def sum(self, o):
-        return prim.flattened_sum(tuple(self.call(op) for op in o.ufl_operands))
+        return prim.flattened_sum(self._call_with_preserved_index(o.ufl_operands))
 
     def zero(self, o):
         # UFL has Zeroes with shape. We ignore those indices.
@@ -359,15 +381,17 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
     #
 
     def conditional(self, o):
+        ind = self.indices
+        self.indices = None
         cond = self.call(o.ufl_operands[0])
+        self.indices = ind
 
         # Try to evaluate the condition at code generation time
         try:
             evaluated = eval(str(cond))
         except:
-            return prim.If(self.call(o.ufl_operands[0]),
-                           self.call(o.ufl_operands[1]),
-                           self.call(o.ufl_operands[2]))
+            branches = self._call_with_preserved_index(o.ufl_operands[1:])
+            return prim.If(cond, *branches)
 
         # User code generation time evaluation
         if evaluated: