From 875bb3e4e90204cc7a4958876b2a024fd3bf9bf4 Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Tue, 12 Apr 2016 13:37:24 +0200 Subject: [PATCH] Add a shape determination algorithm --- python/dune/perftool/ufl/shape.py | 37 +++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 python/dune/perftool/ufl/shape.py diff --git a/python/dune/perftool/ufl/shape.py b/python/dune/perftool/ufl/shape.py new file mode 100644 index 00000000..bc9e5bdb --- /dev/null +++ b/python/dune/perftool/ufl/shape.py @@ -0,0 +1,37 @@ +""" An algorithm to determine the shape aka the loop domain for a given index """ + +from ufl.algorithms import MultiFunction + + +class ShapeDetermination(MultiFunction): + def __init__(self, index): + MultiFunction.__init__(self) + self.index = index + + def expr(self, o): + ret = tuple(self(op) for op in o.ufl_operands) + asset = set(ret) + if len(asset) == 0: + return None + if len(asset) == 1: + # All determined shapes are equal. We just use the result. + return ret[0] + if len(asset) == 2: + try: + # Two shapes determined, one might be "None" so we use the other. + asset.remove(None) + return ret[0] + except KeyError: + pass + raise AssertionError("I had trouble determining the shape of an expression!") + + def indexed(self, o): + try: + position = o.ufl_operands[1].indices().index(self.index) + return o.ufl_operands[0].ufl_shape[position] + except ValueError: + return self(o.ufl_operands[0]) + + +def determine_shape(expr, i): + return ShapeDetermination(i)(expr) -- GitLab