diff --git a/python/dune/perftool/ufl/shape.py b/python/dune/perftool/ufl/shape.py new file mode 100644 index 0000000000000000000000000000000000000000..bc9e5bdbf7ff35693cdba9bbb3e8272a07f056f8 --- /dev/null +++ b/python/dune/perftool/ufl/shape.py @@ -0,0 +1,37 @@ +""" An algorithm to determine the shape aka the loop domain for a given index """ + +from ufl.algorithms import MultiFunction + + +class ShapeDetermination(MultiFunction): + def __init__(self, index): + MultiFunction.__init__(self) + self.index = index + + def expr(self, o): + ret = tuple(self(op) for op in o.ufl_operands) + asset = set(ret) + if len(asset) == 0: + return None + if len(asset) == 1: + # All determined shapes are equal. We just use the result. + return ret[0] + if len(asset) == 2: + try: + # Two shapes determined, one might be "None" so we use the other. + asset.remove(None) + return ret[0] + except KeyError: + pass + raise AssertionError("I had trouble determining the shape of an expression!") + + def indexed(self, o): + try: + position = o.ufl_operands[1].indices().index(self.index) + return o.ufl_operands[0].ufl_shape[position] + except ValueError: + return self(o.ufl_operands[0]) + + +def determine_shape(expr, i): + return ShapeDetermination(i)(expr)