diff --git a/python/dune/perftool/ufl/transformations/identitypropagation.py b/python/dune/perftool/ufl/transformations/identitypropagation.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e7bbab52027cefcd08bd7ac5e80d27f192184c5
--- /dev/null
+++ b/python/dune/perftool/ufl/transformations/identitypropagation.py
@@ -0,0 +1,71 @@
+"""
+A transformation to help the form splitting algorithm split
+vector and tensor expressions. In a nutshell does:
+\sum_i f(i)I(i,k) => f(k)
+"""
+
+from dune.perftool.ufl.transformations import ufl_transformation
+
+from ufl.algorithms import MultiFunction
+from ufl.classes import Identity, Index, IntValue, MultiIndex
+
+
+class GetIndexMap(MultiFunction):
+    call = MultiFunction.__call__
+
+    def __call__(self, o):
+        self.free_index = Index(o.ufl_free_indices[0])
+        self.replacemap = {}
+        self.call(o)
+        return self.replacemap
+
+    def expr(self, o):
+        for op in o.ufl_operands:
+            self.call(op)
+
+    def indexed(self, o):
+        op, i = o.ufl_operands
+        if isinstance(op, Identity):
+            assert(len(i) == 2)
+            assert(self.free_index in i)
+            ind, = set(i) - {self.free_index}
+            self.replacemap[ind] = self.free_index
+        else:
+            self.call(op)
+
+
+class IdentityPropagation(MultiFunction):
+    call = MultiFunction.__call__
+
+    def __call__(self, expr):
+        self.replacemap = GetIndexMap()(expr)
+        return self.call(expr)
+
+    def expr(self, o):
+        return self.reuse_if_untouched(o, *tuple(self.call(op) for op in o.ufl_operands))
+
+    def indexed(self, o):
+        op, i = o.ufl_operands
+        if isinstance(op, Identity):
+            return IntValue(1)
+        else:
+            return self.reuse_if_untouched(o, self.call(op), self.call(i))
+
+    def multi_index(self, o):
+        return MultiIndex(tuple(self.replacemap.get(i, i) for i in o))
+
+    def index_sum(self, o):
+        op, i = o.ufl_operands
+
+        if i[0] in self.replacemap:
+            return self.call(op)
+        else:
+            return self.reuse_if_untouched(o, self.call(op), self.call(i))
+
+
+@ufl_transformation(name='identity')
+def identity_propagation(expr):
+    if expr.ufl_free_indices:
+        return IdentityPropagation()(expr)
+    else:
+        return expr