diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py
index 37e1163086efecb056cd90df3f53ef9b8c9e177c..de9c30aea7dd6a08a35dc702aa3c91442209596f 100644
--- a/python/dune/perftool/sumfact/basis.py
+++ b/python/dune/perftool/sumfact/basis.py
@@ -32,7 +32,7 @@ from dune.perftool.pdelab.argument import name_coefficientcontainer
 from dune.perftool.pdelab.geometry import (local_dimension,
                                            world_dimension,
                                            )
-from dune.perftool.loopy.buffer import initialize_buffer
+from dune.perftool.loopy.buffer import initialize_buffer, get_buffer_temporary
 from dune.perftool.sumfact.symbolic import SumfactKernel, SumfactKernelInputBase
 from dune.perftool.options import get_option
 from dune.perftool.pdelab.driver import FEM_name_mangling
@@ -96,7 +96,13 @@ class LFSSumfactKernelInput(SumfactKernelInputBase, ImmutableRecord):
             from dune.perftool.pdelab.argument import pymbolic_coefficient as pc
             coeff = pc(container, lfs, basisiname)
 
-        assignee = prim.Subscript(prim.Variable("input_{}".format(sf.buffer)),
+        # Get the input temporary!
+        name = get_buffer_temporary(sf.buffer,
+                                    shape=(product(mat.basis_size for mat in sf.matrix_sequence), sf.vector_width),
+                                    name="input_{}".format(sf.buffer)
+                                    )
+
+        assignee = prim.Subscript(prim.Variable(name),
                                   (prim.Variable(basisiname),) + (index,))
         instruction(assignee=assignee,
                     expression=coeff,
diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py
index 1bd5ea3fab0c579b091e9c70b820587dd5741456..f939feaeb6fc8da05828e88ef25fece71ecf9408 100644
--- a/python/dune/perftool/sumfact/realization.py
+++ b/python/dune/perftool/sumfact/realization.py
@@ -61,12 +61,6 @@ def _realize_sum_factorization_kernel(sf):
     if sf.stage == 1 and not get_option("fastdg"):
         assert sf.input
 
-        # Get the input temporary!
-        input_setup = get_buffer_temporary(sf.buffer,
-                                           shape=sf.flat_input_shape,
-                                           name="input_{}".format(sf.buffer)
-                                           )
-
         if sf.vectorized:
             for i, inputsf in enumerate(sf.kernels):
                 inputsf.input.realize(sf, i, inputsf.insn_dep.union(insn_dep))
diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py
index 88f884c80a54ff82aeeb184b636753becfae4dce..88e82c4aff265f0221608a9afaf18ef5a59abc62 100644
--- a/python/dune/perftool/sumfact/symbolic.py
+++ b/python/dune/perftool/sumfact/symbolic.py
@@ -15,7 +15,9 @@ import inspect
 
 
 class SumfactKernelInputBase(object):
-    pass
+    @property
+    def flat_shape(self):
+        return False
 
 
 class SumfactKernelBase(object):
@@ -192,12 +194,6 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
         """
         return 0
 
-    @property
-    def flat_input_shape(self):
-        """ The 'flat' input tensor shape """
-        assert self.stage == 1
-        return (product(mat.basis_size for mat in self.matrix_sequence), 1)
-
     @property
     def quadrature_shape(self):
         """ The shape of a temporary for the quadrature points
@@ -283,6 +279,10 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
     def vertical_width(self):
         return 1
 
+    @property
+    def vector_width(self):
+        return 1
+
 # Extract the argument list and store it on the class. This needs to be done
 # outside of the class because the SumfactKernel class object needs to be fully
 # initialized in order to extract the information from __init__.
@@ -470,10 +470,6 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
 
         return self.horizontal_index(sf) + prim.Remainder(sliced, self.vertical_width)
 
-    @property
-    def flat_input_shape(self):
-        return (product(mat.basis_size for mat in self.matrix_sequence), self.vector_width)
-
     @property
     def quadrature_shape(self):
         return tuple(mat.quadrature_size for mat in self.matrix_sequence) + (self.vector_width,)