Skip to content
Snippets Groups Projects
Commit 875bb3e4 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Add a shape determination algorithm

parent e1d99888
No related branches found
No related tags found
No related merge requests found
""" An algorithm to determine the shape aka the loop domain for a given index """
from ufl.algorithms import MultiFunction
class ShapeDetermination(MultiFunction):
def __init__(self, index):
MultiFunction.__init__(self)
self.index = index
def expr(self, o):
ret = tuple(self(op) for op in o.ufl_operands)
asset = set(ret)
if len(asset) == 0:
return None
if len(asset) == 1:
# All determined shapes are equal. We just use the result.
return ret[0]
if len(asset) == 2:
try:
# Two shapes determined, one might be "None" so we use the other.
asset.remove(None)
return ret[0]
except KeyError:
pass
raise AssertionError("I had trouble determining the shape of an expression!")
def indexed(self, o):
try:
position = o.ufl_operands[1].indices().index(self.index)
return o.ufl_operands[0].ufl_shape[position]
except ValueError:
return self(o.ufl_operands[0])
def determine_shape(expr, i):
return ShapeDetermination(i)(expr)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment