diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py
index 28791c17bdae29dd30d5718cac4cfdb829b535b4..0ee51952d2c49adc09299996e430d6675da62524 100644
--- a/python/dune/perftool/sumfact/sumfact.py
+++ b/python/dune/perftool/sumfact/sumfact.py
@@ -394,6 +394,20 @@ def _sf_permutation_strategy(a_matrices, stage):
     return perm
 
 
+def _permute_forward(t, perm):
+    tmp = []
+    for pos in perm:
+        tmp.append(t[pos])
+    return tuple(tmp)
+
+
+def _permute_backward(t, perm):
+    tmp = [None]*len(t)
+    for i, pos in enumerate(perm):
+        tmp[pos] = t[i]
+    return tuple(tmp)
+
+
 @generator_factory(item_tags=("sumfactkernel",), context_tags=("kernel",), cache_key_generator=lambda a, b, s, **kw: (a, b, s, kw.get("restriction", 0)))
 def sum_factorization_kernel(a_matrices,
                              buf,
@@ -437,6 +451,15 @@ def sum_factorization_kernel(a_matrices,
     Note: In the code below the transformation step is directly done
     in the reduction instruction by adapting the assignee!
 
+    It can make sense to permute the order of directions. If you have
+    a small m_l (e.g. stage 1 on faces) it is better to do direction l
+    first. This can be done permuting:
+
+    - The order of the A matrices.
+    - Permuting the input tensor.
+    - Permuting the output tensor (this assures that the directions of
+      the output tensor are again ordered from 0 to d-1).
+
     Arguments:
     ----------
     a_matrices: An iterable of AMatrix instances
@@ -490,38 +513,10 @@ def sum_factorization_kernel(a_matrices,
     # face.
     #
     # Rule of thumb: small m's early and large n's late.
-
-    # palpo TODO
-    if stage==3 and outshape!=None:
-        from IPython import embed; embed(); import sys; sys.exit("Error message")
-
-    if stage==1 or stage==3:
-        # perm = range(len(a_matrices))
-        perm = _sf_permutation_strategy(a_matrices, stage)
-    else:
-        perm = range(len(a_matrices))
-
-    # # palpo TODO
-    # print("##  PALPO")
-    # shape = [(mat.rows,mat.cols) for mat in a_matrices]
-    # print(shape)
+    perm = _sf_permutation_strategy(a_matrices, stage)
 
     # Permute a_matrices
-    new_a_matrices = []
-    for pos in perm:
-        new_a_matrices.append(a_matrices[pos])
-    a_matrices = tuple(new_a_matrices)
-
-    # # palpo TODO
-    # shape = [(mat.rows,mat.cols) for mat in a_matrices]
-    # print(shape)
-
-    # new_a_matrices = [None]*len(a_matrices)
-    # for i, pos in enumerate(perm):
-    #     new_a_matrices[pos] = a_matrices[i]
-    # a_matrices = tuple(new_a_matrices)
-    # shape = [(mat.rows,mat.cols) for mat in a_matrices]
-    # print(shape)
+    a_matrices = _permute_forward(a_matrices, perm)
 
     # Product of all matrices
     for l, a_matrix in enumerate(a_matrices):
@@ -565,32 +560,18 @@ def sum_factorization_kernel(a_matrices,
                 input_summand = prim.Subscript(prim.Variable(direct_input),
                                                palpo + vec_iname)
         else:
+            # If we did permute the order of a matrices above we also
+            # permuted the order of out_inames. Unfortunately the
+            # order of our input is from 0 to d-1. This means we need
+            # to permute _back_ to get the right coefficients.
+            input_inames = (k_expr,) + tuple(prim.Variable(j) for j in out_inames[1:])
+            if l == 0:
+                inp_shape = _permute_backward(inp_shape, perm)
+                input_inames = _permute_backward(input_inames, perm)
+
             # Get a temporary that interprets the base storage of the input
             # as a column-major matrix. In later iteration of the amatrix loop
             # this reinterprets the output of the previous iteration.
-            palpo = (k_expr,) + tuple(prim.Variable(j) for j in out_inames[1:])
-            if l==0:
-                tmp_perm = [None]*len(inp_shape)
-                for i, pos in enumerate(perm):
-                    tmp_perm[pos] = inp_shape[i]
-                inp_shape = tuple(tmp_perm)
-
-                tmp_perm = [None]*len(palpo)
-                for i, pos in enumerate(perm):
-                    tmp_perm[pos] = palpo[i]
-                palpo = tuple(tmp_perm)
-
-            # tmp_perm = []
-            # for pos in perm:
-            #     tmp_perm.append(inp_shape[pos])
-            # inp_shape = tuple(tmp_perm)
-
-            # palpo = (k_expr,) + tuple(prim.Variable(j) for j in out_inames[1:])
-            # tmp_perm = []
-            # for pos in perm:
-            #     tmp_perm.append(palpo[pos])
-            # palpo = tuple(tmp_perm)
-
             inp = get_buffer_temporary(buf,
                                        shape=inp_shape + vec_shape,
                                        dim_tags=ftags)
@@ -599,7 +580,7 @@ def sum_factorization_kernel(a_matrices,
             silenced_warning('read_no_write({})'.format(inp))
 
             input_summand = prim.Subscript(prim.Variable(inp),
-                                           palpo + vec_iname)
+                                           input_inames + vec_iname)
 
         switch_base_storage(buf)
 
@@ -610,32 +591,16 @@ def sum_factorization_kernel(a_matrices,
         # corresponding shape (out_shape[0]) goes to the end (slowest
         # direction) and everything stays column major (ftags->fortran
         # style).
-        # if False:
+        #
+        # If we are in the last step we reverse the permutation.
+        output_shape = tuple(out_shape[1:]) + (out_shape[0],)
         if l == len(a_matrices)-1:
-            out_shape = tuple(out_shape[1:]) + (out_shape[0],)
-            tmp_perm = [None]*len(out_shape)
-            for i, pos in enumerate(perm):
-                tmp_perm[pos] = out_shape[i]
-            out_shape = tuple(tmp_perm)
-
-            # tmp_perm = []
-            # for pos in perm:
-            #     tmp_perm.append(out_shape[pos])
-            # out_shape = tuple(tmp_perm)
-
-
-            out = get_buffer_temporary(buf,
-                                       shape=out_shape + vec_shape,
-                                       dim_tags=ftags)
-        else:
-            out = get_buffer_temporary(buf,
-                                       shape=tuple(out_shape[1:]) + (out_shape[0],) + vec_shape,
-                                       dim_tags=ftags)
+            output_shape = _permute_backward(output_shape, perm)
+        out = get_buffer_temporary(buf,
+                                   shape=output_shape + vec_shape,
+                                   dim_tags=ftags)
 
         # Write the matrix-matrix multiplication expression
-        # matprod = Product((prim.Subscript(prim.Variable(a_matrix.name),
-        #                                   (prim.Variable(i), k_expr) + vec_iname),
-        #                    input_summand))
         matprod = Product((prim.Subscript(prim.Variable(a_matrix.name),
                                           (prim.Variable(out_inames[0]), k_expr) + vec_iname),
                            input_summand))
@@ -644,24 +609,12 @@ def sum_factorization_kernel(a_matrices,
         if a_matrix.cols != 1:
             matprod = lp.Reduction("sum", k, matprod)
 
-        # Here we also move the new direction (out_inames[0]) to the end
-        # if False:
+        # Here we also move the new direction (out_inames[0]) to the
+        # end and reverse permutation
+        output_inames = tuple(prim.Variable(i) for i in out_inames[1:]) + (prim.Variable(out_inames[0]),)
         if l == len(a_matrices)-1:
-            palpo = tuple(prim.Variable(i) for i in out_inames[1:]) + (prim.Variable(out_inames[0]),)
-            tmp_perm = [None]*len(palpo)
-            for i, pos in enumerate(perm):
-                tmp_perm[pos] = palpo[i]
-            palpo = tuple(tmp_perm)
-
-            # palpo = tuple(prim.Variable(i) for i in out_inames[1:]) + (prim.Variable(out_inames[0]),)
-            # tmp_perm = []
-            # for pos in perm:
-            #     tmp_perm.append(palpo[pos])
-            # palpo = tuple(tmp_perm)
-
-            assignee = prim.Subscript(prim.Variable(out), palpo + vec_iname)
-        else:
-            assignee = prim.Subscript(prim.Variable(out), tuple(prim.Variable(i) for i in out_inames[1:]) + (prim.Variable(out_inames[0]),) + vec_iname)
+            output_inames = _permute_backward(output_inames, perm)
+        assignee = prim.Subscript(prim.Variable(out), output_inames + vec_iname)
 
         # Issue the reduction instruction that implements the multiplication
         # at the same time store the instruction ID for the next instruction to depend on