From 875bb3e4e90204cc7a4958876b2a024fd3bf9bf4 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Tue, 12 Apr 2016 13:37:24 +0200
Subject: [PATCH] Add a shape determination algorithm

---
 python/dune/perftool/ufl/shape.py | 37 +++++++++++++++++++++++++++++++
 1 file changed, 37 insertions(+)
 create mode 100644 python/dune/perftool/ufl/shape.py

diff --git a/python/dune/perftool/ufl/shape.py b/python/dune/perftool/ufl/shape.py
new file mode 100644
index 00000000..bc9e5bdb
--- /dev/null
+++ b/python/dune/perftool/ufl/shape.py
@@ -0,0 +1,37 @@
+""" An algorithm to determine the shape aka the loop domain for a given index """
+
+from ufl.algorithms import MultiFunction
+
+
+class ShapeDetermination(MultiFunction):
+    def __init__(self, index):
+        MultiFunction.__init__(self)
+        self.index = index
+
+    def expr(self, o):
+        ret = tuple(self(op) for op in o.ufl_operands)
+        asset = set(ret)
+        if len(asset) == 0:
+            return None
+        if len(asset) == 1:
+            # All determined shapes are equal. We just use the result.
+            return ret[0]
+        if len(asset) == 2:
+            try:
+                # Two shapes determined, one might be "None" so we use the other.
+                asset.remove(None)
+                return ret[0]
+            except KeyError:
+                pass
+        raise AssertionError("I had trouble determining the shape of an expression!")
+
+    def indexed(self, o):
+        try:
+            position = o.ufl_operands[1].indices().index(self.index)
+            return o.ufl_operands[0].ufl_shape[position]
+        except ValueError:
+            return self(o.ufl_operands[0])
+
+
+def determine_shape(expr, i):
+    return ShapeDetermination(i)(expr)
-- 
GitLab