diff --git a/python/dune/perftool/loopy/mangler.py b/python/dune/perftool/loopy/mangler.py
index 337205b3a5d285005288afed3bab845547cd5065..3d6dd2ad0b6ee9174c3218b1bd21801344c6c412 100644
--- a/python/dune/perftool/loopy/mangler.py
+++ b/python/dune/perftool/loopy/mangler.py
@@ -16,8 +16,8 @@ def using_std_statement(name):
 
 @function_mangler
 def dune_math_manglers(kernel, name, arg_dtypes):
-    dt = arg_dtypes[0]
     if name == "exp":
+        dt = arg_dtypes[0]
         using_std_statement(name)
         include_file("dune/perftool/common/vectorclass.hh", filetag="operatorfile")
         return CallMangleInfo("exp",
@@ -26,6 +26,7 @@ def dune_math_manglers(kernel, name, arg_dtypes):
                               )
 
     if name == "sqrt":
+        dt = arg_dtypes[0]
         using_std_statement(name)
         return CallMangleInfo("sqrt",
                               arg_dtypes,
@@ -33,6 +34,7 @@ def dune_math_manglers(kernel, name, arg_dtypes):
                               )
 
     if name == "max":
+        dt = arg_dtypes[0]
         using_std_statement(name)
         return CallMangleInfo("max",
                               (dt,),
@@ -40,6 +42,7 @@ def dune_math_manglers(kernel, name, arg_dtypes):
                               )
 
     if name == "min":
+        dt = arg_dtypes[0]
         using_std_statement(name)
         return CallMangleInfo("min",
                               (dt,),
diff --git a/python/dune/perftool/options.py b/python/dune/perftool/options.py
index 9f75b3cebf075ab3dc6280f0d563bf9e039b9597..ded7826e994cd51a017415825775f7a5d5d8ebcd 100644
--- a/python/dune/perftool/options.py
+++ b/python/dune/perftool/options.py
@@ -48,6 +48,7 @@ def get_form_compiler_arguments():
     parser.add_argument("--sumfact", action="store_true", help="Use sumfactorization")
     parser.add_argument("--vectorize-quad", action="store_true", help="whether to generate code with explicit vectorization")
     parser.add_argument("--vectorize-grads", action="store_true", help="whether to generate code with explicit vectorization")
+    parser.add_argument("--turn-off-diagonal-jacobian", action="store_true", help="Do not use diagonal_jacobian transformation on the ufl tree and cast result of jacobianInverseTransposed into a FieldMatrix.")
 
     # Modify the positional argument to not be a list
     args = vars(parser.parse_args())
diff --git a/python/dune/perftool/pdelab/__init__.py b/python/dune/perftool/pdelab/__init__.py
index dcc17068303501fc01d96e842089454907342c3b..4afe9a97e1fe833af5df9edb204a4dc52bb75e28 100644
--- a/python/dune/perftool/pdelab/__init__.py
+++ b/python/dune/perftool/pdelab/__init__.py
@@ -31,7 +31,7 @@ from dune.perftool.pdelab.quadrature import (pymbolic_quadrature_weight,
                                              )
 from dune.perftool.pdelab.spaces import (lfs_inames,
                                          )
-from dune.perftool.pdelab.tensors import pymbolic_list_tensor
+from dune.perftool.pdelab.tensors import pymbolic_list_tensor, pymbolic_identity
 
 
 class PDELabInterface(object):
@@ -92,6 +92,9 @@ class PDELabInterface(object):
     def pymbolic_list_tensor(self, o, visitor):
         return pymbolic_list_tensor(o, visitor)
 
+    def pymbolic_identity(self, o):
+        return pymbolic_identity(o)
+
     #
     # Geometry related generator functions
     #
diff --git a/python/dune/perftool/pdelab/geometry.py b/python/dune/perftool/pdelab/geometry.py
index 9e98245c3410528780dcf31702dade76751086d2..9da0f1ee86df3e836a8091cc55ddfe3d148e0ab0 100644
--- a/python/dune/perftool/pdelab/geometry.py
+++ b/python/dune/perftool/pdelab/geometry.py
@@ -288,8 +288,12 @@ def name_unit_inner_normal():
 
 
 def type_jacobian_inverse_transposed(restriction):
-    geo = type_cell_geometry(restriction)
-    return "typename {}::JacobianInverseTransposed".format(geo)
+    if get_option('turn_off_diagonal_jacobian'):
+        dim = world_dimension()
+        return "typename Dune::FieldMatrix<double,{},{}>".format(dim,dim)
+    else:
+        geo = type_cell_geometry(restriction)
+        return "typename {}::JacobianInverseTransposed".format(geo)
 
 
 @kernel_cached
@@ -316,10 +320,12 @@ def define_constant_jacobian_inveser_transposed(name, restriction):
 
     globalarg(name, dtype=np.float64, shape=(dim, dim), managed=False)
 
-    return 'auto {} = {}.jacobianInverseTransposed({});'.format(name,
-                                                                geo,
-                                                                pos,
-                                                                )
+    jit_type = type_jacobian_inverse_transposed(restriction)
+    return '{} {} = {}.jacobianInverseTransposed({});'.format(jit_type,
+                                                              name,
+                                                              geo,
+                                                              pos,
+                                                              )
 
 
 @backend(interface="define_jit", name="default")
diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py
index 957c767dc9f3d6bae59c084941247f8f057e4a0e..1a5f8faf52f78c1732412ee0f579f53b07a8482b 100644
--- a/python/dune/perftool/pdelab/localoperator.py
+++ b/python/dune/perftool/pdelab/localoperator.py
@@ -376,8 +376,9 @@ def visit_integrals(integrals):
 
         # Maybe make the jacobian inverse diagonal!
         if get_option('diagonal_transformation_matrix'):
-            from dune.perftool.ufl.transformations.axiparallel import diagonal_jacobian
-            integrand = diagonal_jacobian(integrand)
+            if not get_option('turn_off_diagonal_jacobian'):
+                from dune.perftool.ufl.transformations.axiparallel import diagonal_jacobian
+                integrand = diagonal_jacobian(integrand)
 
         # Gather dimension indices
         from dune.perftool.ufl.dimensionindex import dimension_index_mapping
diff --git a/python/dune/perftool/pdelab/parameter.py b/python/dune/perftool/pdelab/parameter.py
index 2d8162495cfbdc2f03d85ac31418363d0a1c571a..8e8af15b399e0012dab5016c5ae5185f14eb83f3 100644
--- a/python/dune/perftool/pdelab/parameter.py
+++ b/python/dune/perftool/pdelab/parameter.py
@@ -13,10 +13,7 @@ from dune.perftool.generation import (class_basename,
 from dune.perftool.pdelab.geometry import (name_cell,
                                            name_intersection,
                                            )
-from dune.perftool.pdelab.quadrature import (pymbolic_quadrature_position,
-                                             pymbolic_quadrature_position_in_cell,
-                                             quadrature_preamble,
-                                             )
+from dune.perftool.pdelab.quadrature import quadrature_preamble
 from dune.perftool.tools import get_pymbolic_basename
 from dune.perftool.cgen.clazz import AccessModifier
 from dune.perftool.pdelab.localoperator import (class_type_from_cache,
diff --git a/python/dune/perftool/pdelab/tensors.py b/python/dune/perftool/pdelab/tensors.py
index 315e3ceccee0d4d203bef033c901043cff4efa34..aa154bdb1a41a8dfc54b8faf92043be5a4c0b78a 100644
--- a/python/dune/perftool/pdelab/tensors.py
+++ b/python/dune/perftool/pdelab/tensors.py
@@ -1,7 +1,9 @@
 """ Code generation for explicitly specified tensors """
 
 from dune.perftool.generation import (get_counted_variable,
+                                      domain,
                                       kernel_cached,
+                                      iname,
                                       instruction,
                                       temporary_variable,
                                       )
@@ -33,3 +35,31 @@ def pymbolic_list_tensor(expr, visitor):
                        )
     define_list_tensor(name, expr, visitor)
     return prim.Variable(name)
+
+
+@iname
+def identity_iname(name, bound):
+    name = "id_{}_{}".format(name, bound)
+    domain(name, bound)
+    return name
+
+
+def define_identity(name, expr):
+    i = identity_iname("i", expr.ufl_shape[0])
+    j = identity_iname("j", expr.ufl_shape[1])
+    instruction(assignee=prim.Subscript(prim.Variable(name), (prim.Variable(i), prim.Variable(j))),
+                expression=prim.If(prim.Comparison(prim.Variable(i),"==",prim.Variable(j)),1,0),
+                forced_iname_deps_is_final=True,
+                )
+
+
+@kernel_cached
+def pymbolic_identity(expr):
+    name = "identity_{}_{}".format(expr.ufl_shape[0],expr.ufl_shape[1])
+    temporary_variable(name,
+                       shape=expr.ufl_shape,
+                       shape_impl=('fm',),
+                       dtype=np.float64,
+                       )
+    define_identity(name, expr)
+    return prim.Variable(name)
diff --git a/python/dune/perftool/ufl/extract_accumulation_terms.py b/python/dune/perftool/ufl/extract_accumulation_terms.py
index b5d473b9a3a90cfe72acecc89dcbb0289e41e54a..4dbebbb8233f06cbb42ac273a72d609dc8fb5059 100644
--- a/python/dune/perftool/ufl/extract_accumulation_terms.py
+++ b/python/dune/perftool/ufl/extract_accumulation_terms.py
@@ -12,7 +12,7 @@ from dune.perftool.ufl.transformations.reindexing import reindexing
 from dune.perftool.ufl.modified_terminals import analyse_modified_argument, ModifiedArgument
 from dune.perftool.pdelab.restriction import Restriction
 
-from ufl.classes import Zero, Identity, Indexed, IntValue, MultiIndex, Product
+from ufl.classes import Zero, Identity, Indexed, IntValue, MultiIndex, Product, IndexSum
 from ufl.core.multiindex import indices
 
 from pytools import Record
@@ -107,6 +107,7 @@ def split_into_accumulation_terms(expr):
         replacement = {}
         indexmap = {}
         newi = None
+        backmap = {}
         # Get all appearances of test functions with their indices
         indexed_test_args = extract_modified_arguments(replace_expr, argnumber=0, do_index=True)
         for indexed_test_arg in indexed_test_args:
@@ -115,20 +116,53 @@ def split_into_accumulation_terms(expr):
                 # -> (m,n) in the example above
                 if newi is None:
                     newi = indices(len(indexed_test_arg.index))
-                # Replace indexed test function with a product of identities.
-                identities = tuple(Indexed(Identity(2), MultiIndex((i,) + (j,)))
-                                   for i, j in zip(newi, indexed_test_arg.index._indices))
-                replacement.update({indexed_test_arg.expr:
-                                    construct_binary_operator(identities, Product)})
-                indexmap.update({i: j for i, j in zip(indexed_test_arg.index._indices, newi)})
-                indexed_test_arg = analyse_modified_argument(reindexing(indexed_test_arg.expr,
-                                                                        replacemap=indexmap))
+
+                # This handles the special case with two identical
+                # indices on an test function. E.g. in Stokes on an
+                # axiparallel grid you get a term:
+                #
+                # -(\sum_i K_{i,i} (\nabla v)_{i,i}) w
+                #   = \sum_k \sum_l (-K_{k,k} w I_{k,l} (\nabla v)_{k,l})
+                #
+                # and we want to split
+                #
+                # -K_{k,k} w I_{k,l} corresponding to (\nabla v)_{k,l}.
+                #
+                # This is done by:
+                # - Replacing (\nabla v)_{i,i} with I_{k,i}*(\nabla
+                #   v)_{k,l}. Here (\nabla v)_{k,l} serves as a
+                #   placeholder and will be replaced later on.
+                # - Propagating the identity in step 4.
+                # - Replacing (\nabla v)_{k,l} by I_{k,l} after step 4.
+                if len(set(indexed_test_arg.index._indices)) < len(indexed_test_arg.index._indices):
+                    if len(indexed_test_arg.index._indices)>2:
+                        raise NotImplementedError("Test argument with more than three indices and double occurence ist not implemented.")
+                    mod_index_map = {indexed_test_arg.index: MultiIndex((newi[0], newi[1]))}
+                    mod_indexed_test_arg = replace_expression(indexed_test_arg.expr,
+                                                              replacemap = mod_index_map)
+                    rep = Product(Indexed(Identity(2),
+                                          MultiIndex((newi[0],indexed_test_arg.index[0]))),
+                                  mod_indexed_test_arg)
+                    backmap.update({mod_indexed_test_arg:
+                                    Indexed(Identity(2), MultiIndex((newi[0],newi[1])))})
+                    replacement.update({indexed_test_arg.expr: rep})
+                    indexmap.update({indexed_test_arg.index[0]: newi[0]})
+                else:
+                    # Replace indexed test function with a product of identities.
+                    identities = tuple(Indexed(Identity(2), MultiIndex((i,) + (j,)))
+                                       for i, j in zip(newi, indexed_test_arg.index._indices))
+                    replacement.update({indexed_test_arg.expr:
+                                        construct_binary_operator(identities, Product)})
+
+                    indexmap.update({i: j for i, j in zip(indexed_test_arg.index._indices, newi)})
             else:
                 replacement.update({indexed_test_arg.expr: IntValue(1)})
         replace_expr = replace_expression(replace_expr, replacemap=replacement)
 
-        # 4) Collapse any identity nodes that may have been introduced by replacing vectors
+        # 4) Collapse any identity nodes that may have been introduced
+        # by replacing vectors and maybe replace placeholder from last step
         replace_expr = identity_propagation(replace_expr)
+        replace_expr = replace_expression(replace_expr, replacemap=backmap)
 
         # 5) Further split according to trial function in jacobian terms
         #
diff --git a/python/dune/perftool/ufl/transformations/axiparallel.py b/python/dune/perftool/ufl/transformations/axiparallel.py
index 30eb43dee2b2d1c2a116b2c86ad8e0d2ad2d4c88..6d7e53e4a2468c0198dabb349017fc885915ef04 100644
--- a/python/dune/perftool/ufl/transformations/axiparallel.py
+++ b/python/dune/perftool/ufl/transformations/axiparallel.py
@@ -13,44 +13,46 @@ from ufl.classes import (Indexed,
                          )
 
 
-class DiagonalJITReplacer(MultiFunction):
-    call = MultiFunction.__call__
+class LocalDiagonalJITReplacer(MultiFunction):
+    """Make JacobianInverse diagonal and simplify tree
 
-    def __call__(self, o, delete, replace):
-        self.delete = delete
-        self.replace = replace
+    Search for:
 
-        return self.call(o)
+    (IndexSum)-L->(Product)-L->(Indexed)-L->(JacobianInverse)
 
-    def expr(self, o):
-        return self.reuse_if_untouched(o, *tuple(self.call(op) for op in o.ufl_operands))
+    and replace it with
 
-    def multi_index(self, o):
-        return MultiIndex(tuple(self.replace.get(i, i) for i in o))
-
-    def index_sum(self, o):
-        if o in self.delete:
-            return self.call(o.ufl_operands[0])
-        else:
-            return self.reuse_if_untouched(o, *tuple(self.call(op) for op in o.ufl_operands))
+    (Product)-L->(Indexed)-L->(JacobianInverse)
 
+    where the index of the initial IndexSum is replaced by the second
+    index of the JacobianInverse.
 
-class DiagonalJITFinder(MultiFunction):
+    Local means: Instead of replacing all appearances of this Index in
+    the graph only this subtree is changed.
+    """
     call = MultiFunction.__call__
 
     def __call__(self, o):
-        self.deleted_index_sums = []
-        self.replacemap = {}
+        self.local_replacemap = {}
+        return self.call(o)
 
-        self.call(o)
+    def expr(self, o):
+        return self.reuse_if_untouched(o, *tuple(self.call(op) for op in o.ufl_operands))
 
-        return DiagonalJITReplacer()(o, self.deleted_index_sums, self.replacemap)
+    def multi_index(self, o):
+        return MultiIndex(tuple(self.local_replacemap.get(i, i) for i in o))
+
+    def indexed(self, o):
+        if self.local_replacemap:
+            if isinstance(o.ufl_operands[1], MultiIndex):
+                return self.reuse_if_untouched(o, *tuple(self.call(op) for op in o.ufl_operands))
+        else:
+            return self.reuse_if_untouched(o, *tuple(self.call(op) for op in o.ufl_operands))
 
-    def expr(self, o):
-        for op in o.ufl_operands:
-            self.call(op)
 
     def index_sum(self, o):
+        # Check if we want to replace an index here and store it in
+        # local_replacemap
         e, i = o.ufl_operands
         if isinstance(e, Product):
             p1, p2 = e.ufl_operands
@@ -60,15 +62,19 @@ class DiagonalJITFinder(MultiFunction):
                     p = p.ufl_operands[0]
                 if isinstance(p, JacobianInverse):
                     assert(i[0] == ii[0])
-                    if ii[0] in self.replacemap:
-                        self.replacemap[ii[1]] = self.replacemap[ii[0]]
-                    else:
-                        self.replacemap[ii[0]] = ii[1]
-                    self.deleted_index_sums.append(o)
-                    return
+                    assert(self.local_replacemap == {})
+                    self.lock_replacemap = True
+                    self.local_replacemap[ii[0]] = ii[1]
+
+        if self.local_replacemap == {}:
+            return self.reuse_if_untouched(o, *tuple(self.call(op) for op in o.ufl_operands))
+        else:
+            # Go through subtrees and replace indices if necessary
+            e = self.call(e)
 
-        for op in o.ufl_operands:
-            self.call(op)
+            # Clear replacemap
+            self.local_replacemap.clear()
+            return e
 
 
 @ufl_transformation(name='axiparallel')
@@ -77,4 +83,4 @@ def diagonal_jacobian(e):
     This transformations can generically be described as:
     \sum_k J^{(+/-)}_{k,i}x_k ==> J^{(+/-)}_{i,i}x_i
     """
-    return DiagonalJITFinder()(e)
+    return LocalDiagonalJITReplacer()(e)
diff --git a/python/dune/perftool/ufl/visitor.py b/python/dune/perftool/ufl/visitor.py
index 503bfd83959e3d3efb7ae597e11b352c4b5f5952..711f92bf6b54a614ddcdd647409e55f310461c08 100644
--- a/python/dune/perftool/ufl/visitor.py
+++ b/python/dune/perftool/ufl/visitor.py
@@ -229,6 +229,9 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
     def list_tensor(self, o):
         return self.interface.pymbolic_list_tensor(o, self)
 
+    def identity(self, o):
+        return self.interface.pymbolic_identity(o)
+
     #
     # Handlers for arithmetic operators and functions
     # Those handlers would be valid in any code going from UFL to pymbolic
diff --git a/test/stokes/CMakeLists.txt b/test/stokes/CMakeLists.txt
index d36e95bf33323f370570ed7ff91d8f68264a0341..9d3dd53bf6822563b84d158278e33687222494d9 100644
--- a/test/stokes/CMakeLists.txt
+++ b/test/stokes/CMakeLists.txt
@@ -7,6 +7,11 @@ dune_add_formcompiler_system_test(UFLFILE stokes.ufl
                                   INIFILE stokes.mini
                                   )
 
+dune_add_formcompiler_system_test(UFLFILE stokes_quadrilateral.ufl
+                                  BASENAME stokes_quadrilateral
+                                  INIFILE stokes_quadrilateral.mini
+                                  )
+
 dune_add_formcompiler_system_test(UFLFILE stokes_sym.ufl
                                   BASENAME stokes_sym
                                   INIFILE stokes_sym.mini
diff --git a/test/stokes/stokes_quadrilateral.mini b/test/stokes/stokes_quadrilateral.mini
new file mode 100644
index 0000000000000000000000000000000000000000..0c2a459d2ebdec9279635e47c679d13aff548c89
--- /dev/null
+++ b/test/stokes/stokes_quadrilateral.mini
@@ -0,0 +1,16 @@
+__name = stokes_quadrilateral_{__exec_suffix}
+
+__exec_suffix = {diff_suffix}
+diff_suffix = numdiff, symdiff | expand num
+
+cells = 8 8
+extension = 1. 1.
+
+[wrapper.vtkcompare]
+name = {__name}
+extension = vtu
+
+[formcompiler]
+numerical_jacobian = 1, 0 | expand num
+exact_solution_expression = g
+compare_l2errorsquared = 1e-11
diff --git a/test/stokes/stokes_quadrilateral.ufl b/test/stokes/stokes_quadrilateral.ufl
new file mode 100644
index 0000000000000000000000000000000000000000..0d7c58aa180a9d4586184d1eda6c53353bbbb67a
--- /dev/null
+++ b/test/stokes/stokes_quadrilateral.ufl
@@ -0,0 +1,18 @@
+cell = quadrilateral
+
+x = SpatialCoordinate(cell)
+v_bctype = conditional(x[0] < 1. - 1e-8, 1, 0)
+g_v = as_vector((4.*x[1]*(1.-x[1]), 0.0))
+g_p = 8.*(1.-x[0])
+g = (g_v, g_p)
+
+P2 = VectorElement("Lagrange", cell, 2, dirichlet_constraints=v_bctype, dirichlet_expression=g_v)
+P1 = FiniteElement("Lagrange", cell, 1)
+TH = P2 * P1
+
+v, q = TestFunctions(TH)
+u, p = TrialFunctions(TH)
+
+r = (inner(grad(v), grad(u)) - div(v)*p - q*div(u))*dx
+
+forms = [r]