diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py
index de9c30aea7dd6a08a35dc702aa3c91442209596f..04bf64bad3349da40f153261cf308f0ba6a2edb7 100644
--- a/python/dune/perftool/sumfact/basis.py
+++ b/python/dune/perftool/sumfact/basis.py
@@ -110,6 +110,13 @@ class LFSSumfactKernelInput(SumfactKernelInputBase, ImmutableRecord):
                     tags=frozenset({"sumfact_stage{}".format(sf.stage)}),
                     )
 
+    @property
+    def direct_input(self):
+        if get_option("fastdg"):
+            return self.coeff_func(self.restriction)
+        else:
+            return None
+
 
 def _basis_functions_per_direction(element, component):
     """Number of basis functions per direction of a given component of an element"""
diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py
index f939feaeb6fc8da05828e88ef25fece71ecf9408..8922a0b06c8bd5c0e7bc21afe0585b08bf804ab3 100644
--- a/python/dune/perftool/sumfact/realization.py
+++ b/python/dune/perftool/sumfact/realization.py
@@ -57,10 +57,10 @@ def _realize_sum_factorization_kernel(sf):
                                                          ),
                                              }))
 
-    # Set up the input for stage 1
-    if sf.stage == 1 and not get_option("fastdg"):
-        assert sf.input
+    direct_input = sf.input.direct_input
 
+    # Set up the input for stage 1
+    if direct_input is None:
         if sf.vectorized:
             for i, inputsf in enumerate(sf.kernels):
                 inputsf.input.realize(sf, i, inputsf.insn_dep.union(insn_dep))
@@ -69,11 +69,6 @@ def _realize_sum_factorization_kernel(sf):
 
         insn_dep = insn_dep.union(frozenset({lp.match.Writes("input_{}".format(sf.buffer))}))
 
-    # Construct the direct_input for the FastDG case
-    direct_input = None
-    if get_option('fastdg') and sf.stage == 1:
-        direct_input = sf.input.coeff_func(sf.input.restriction)
-
     direct_output = None
     if get_option('fastdg') and sf.stage == 3:
         direct_output = sf.accumvar + ".data()"
diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py
index a8a42ed170ea97eead8741183d4145f0d5b7ba3e..9a3011cba1b3223e5af7b081a89cecd9d6326549 100644
--- a/python/dune/perftool/sumfact/symbolic.py
+++ b/python/dune/perftool/sumfact/symbolic.py
@@ -16,8 +16,11 @@ import inspect
 
 class SumfactKernelInputBase(object):
     @property
-    def flat_shape(self):
-        return False
+    def direct_input(self):
+        return None
+
+    def realize(self, sf, i, dep):
+        pass
 
 
 class SumfactKernelBase(object):