Skip to content
Snippets Groups Projects
Commit e43a20db authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Add identity propagation transformation

parent fa30c1c6
No related branches found
No related tags found
No related merge requests found
"""
A transformation to help the form splitting algorithm split
vector and tensor expressions. In a nutshell does:
\sum_i f(i)I(i,k) => f(k)
"""
from dune.perftool.ufl.transformations import ufl_transformation
from ufl.algorithms import MultiFunction
from ufl.classes import Identity, Index, IntValue, MultiIndex
class GetIndexMap(MultiFunction):
call = MultiFunction.__call__
def __call__(self, o):
self.free_index = Index(o.ufl_free_indices[0])
self.replacemap = {}
self.call(o)
return self.replacemap
def expr(self, o):
for op in o.ufl_operands:
self.call(op)
def indexed(self, o):
op, i = o.ufl_operands
if isinstance(op, Identity):
assert(len(i) == 2)
assert(self.free_index in i)
ind, = set(i) - {self.free_index}
self.replacemap[ind] = self.free_index
else:
self.call(op)
class IdentityPropagation(MultiFunction):
call = MultiFunction.__call__
def __call__(self, expr):
self.replacemap = GetIndexMap()(expr)
return self.call(expr)
def expr(self, o):
return self.reuse_if_untouched(o, *tuple(self.call(op) for op in o.ufl_operands))
def indexed(self, o):
op, i = o.ufl_operands
if isinstance(op, Identity):
return IntValue(1)
else:
return self.reuse_if_untouched(o, self.call(op), self.call(i))
def multi_index(self, o):
return MultiIndex(tuple(self.replacemap.get(i, i) for i in o))
def index_sum(self, o):
op, i = o.ufl_operands
if i[0] in self.replacemap:
return self.call(op)
else:
return self.reuse_if_untouched(o, self.call(op), self.call(i))
@ufl_transformation(name='identity')
def identity_propagation(expr):
if expr.ufl_free_indices:
return IdentityPropagation()(expr)
else:
return expr
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment