Skip to content
Snippets Groups Projects
Commit 5081753a authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Adjust visitor to allow sums/conditionals of expressions of arbitrary shape

The visitor state regarding indices needs to be correctly reset
for child iteration in these cases.
parent 11871f97
No related branches found
No related tags found
No related merge requests found
...@@ -248,6 +248,8 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): ...@@ -248,6 +248,8 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
if all(isinstance(i, int) for i in self.indices): if all(isinstance(i, int) for i in self.indices):
index = self.indices[0] index = self.indices[0]
self.indices = self.indices[1:] self.indices = self.indices[1:]
if len(self.indices) == 0:
self.indices = None
return self.call(o.ufl_operands[index]) return self.call(o.ufl_operands[index])
else: else:
return self.interface.pymbolic_list_tensor(o) return self.interface.pymbolic_list_tensor(o)
...@@ -274,6 +276,26 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): ...@@ -274,6 +276,26 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
# Those handlers would be valid in any code going from UFL to pymbolic # Those handlers would be valid in any code going from UFL to pymbolic
# #
def _call_with_preserved_index(self, ops):
# This can be used if the indexing state needs to be reused for several
# branches. This happens so far in the following scenarios:
# * Sums of tensors
# * Conditionals
ind = self.indices
after_ind = []
ret = []
for o in ops:
self.indices = ind
ret.append(self.call(o))
after_ind.append(self.indices)
try:
assert(len(set(after_ind)) == 1)
except:
from pudb import set_trace; set_trace()
self.indices = after_ind[0]
return tuple(ret)
def product(self, o): def product(self, o):
return prim.flattened_product(tuple(self.call(op) for op in o.ufl_operands)) return prim.flattened_product(tuple(self.call(op) for op in o.ufl_operands))
...@@ -287,7 +309,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): ...@@ -287,7 +309,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
return prim.quotient(self.call(o.ufl_operands[0]), self.call(o.ufl_operands[1])) return prim.quotient(self.call(o.ufl_operands[0]), self.call(o.ufl_operands[1]))
def sum(self, o): def sum(self, o):
return prim.flattened_sum(tuple(self.call(op) for op in o.ufl_operands)) return prim.flattened_sum(self._call_with_preserved_index(o.ufl_operands))
def zero(self, o): def zero(self, o):
# UFL has Zeroes with shape. We ignore those indices. # UFL has Zeroes with shape. We ignore those indices.
...@@ -359,15 +381,17 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): ...@@ -359,15 +381,17 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
# #
def conditional(self, o): def conditional(self, o):
ind = self.indices
self.indices = None
cond = self.call(o.ufl_operands[0]) cond = self.call(o.ufl_operands[0])
self.indices = ind
# Try to evaluate the condition at code generation time # Try to evaluate the condition at code generation time
try: try:
evaluated = eval(str(cond)) evaluated = eval(str(cond))
except: except:
return prim.If(self.call(o.ufl_operands[0]), branches = self._call_with_preserved_index(o.ufl_operands[1:])
self.call(o.ufl_operands[1]), return prim.If(cond, *branches)
self.call(o.ufl_operands[2]))
# User code generation time evaluation # User code generation time evaluation
if evaluated: if evaluated:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment