From e43a20db66934e6b90be21395dd9cf21b9703253 Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Tue, 11 Oct 2016 15:39:37 +0200 Subject: [PATCH] Add identity propagation transformation --- .../transformations/identitypropagation.py | 71 +++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 python/dune/perftool/ufl/transformations/identitypropagation.py diff --git a/python/dune/perftool/ufl/transformations/identitypropagation.py b/python/dune/perftool/ufl/transformations/identitypropagation.py new file mode 100644 index 00000000..2e7bbab5 --- /dev/null +++ b/python/dune/perftool/ufl/transformations/identitypropagation.py @@ -0,0 +1,71 @@ +""" +A transformation to help the form splitting algorithm split +vector and tensor expressions. In a nutshell does: +\sum_i f(i)I(i,k) => f(k) +""" + +from dune.perftool.ufl.transformations import ufl_transformation + +from ufl.algorithms import MultiFunction +from ufl.classes import Identity, Index, IntValue, MultiIndex + + +class GetIndexMap(MultiFunction): + call = MultiFunction.__call__ + + def __call__(self, o): + self.free_index = Index(o.ufl_free_indices[0]) + self.replacemap = {} + self.call(o) + return self.replacemap + + def expr(self, o): + for op in o.ufl_operands: + self.call(op) + + def indexed(self, o): + op, i = o.ufl_operands + if isinstance(op, Identity): + assert(len(i) == 2) + assert(self.free_index in i) + ind, = set(i) - {self.free_index} + self.replacemap[ind] = self.free_index + else: + self.call(op) + + +class IdentityPropagation(MultiFunction): + call = MultiFunction.__call__ + + def __call__(self, expr): + self.replacemap = GetIndexMap()(expr) + return self.call(expr) + + def expr(self, o): + return self.reuse_if_untouched(o, *tuple(self.call(op) for op in o.ufl_operands)) + + def indexed(self, o): + op, i = o.ufl_operands + if isinstance(op, Identity): + return IntValue(1) + else: + return self.reuse_if_untouched(o, self.call(op), self.call(i)) + + def multi_index(self, o): + return MultiIndex(tuple(self.replacemap.get(i, i) for i in o)) + + def index_sum(self, o): + op, i = o.ufl_operands + + if i[0] in self.replacemap: + return self.call(op) + else: + return self.reuse_if_untouched(o, self.call(op), self.call(i)) + + +@ufl_transformation(name='identity') +def identity_propagation(expr): + if expr.ufl_free_indices: + return IdentityPropagation()(expr) + else: + return expr -- GitLab