diff --git a/python/dune/perftool/sumfact/accumulation.py b/python/dune/perftool/sumfact/accumulation.py
index 50f4d169935cae6ec421da62ba432506d055f298..c9c32b482a84fb32aafd7ee1a546b8c063c362a7 100644
--- a/python/dune/perftool/sumfact/accumulation.py
+++ b/python/dune/perftool/sumfact/accumulation.py
@@ -6,6 +6,7 @@ from dune.perftool.pdelab.argument import (name_accumulation_variable,
 from dune.perftool.generation import (backend,
                                       domain,
                                       dump_accumulate_timer,
+                                      generator_factory,
                                       get_counted_variable,
                                       get_counter,
                                       iname,
@@ -40,10 +41,29 @@ from pytools import ImmutableRecord
 import loopy as lp
 import numpy as np
 import pymbolic.primitives as prim
+from pymbolic.mapper import WalkMapper
 import ufl.classes as uc
 from ufl import FiniteElement, MixedElement, TensorProductElement
 
 
+basis_sf_kernels = generator_factory(item_tags=("basis_sf_kernels",), context_tags='kernel', no_deco=True)
+
+
+class SumfactCollectMapper(WalkMapper):  # walks an expression, recording sumfact kernel nodes
+    def map_tagged_variable(self, expr, *args, **kwargs):  # handled as a leaf: nothing to descend into
+        self.visit(expr, *args, **kwargs)
+        self.post_visit(expr, *args, **kwargs)
+
+    def map_variable(self, expr, *args, **kwargs):  # handled as a leaf: nothing to descend into
+        self.visit(expr, *args, **kwargs)
+        self.post_visit(expr, *args, **kwargs)
+
+    def map_sumfact_kernel(self, expr, *args, **kwargs):  # kernel nodes are registered, not descended into
+        basis_sf_kernels(expr)  # store the kernel under the "basis_sf_kernels" item tag
+        self.visit(expr, *args, **kwargs)
+        self.post_visit(expr, *args, **kwargs)
+
+
 @iname
 def _sumfact_iname(bound, _type, count):
     name = "sf_{}_{}".format(_type, str(count))
@@ -199,6 +219,9 @@ def generate_accumulation_instruction(expr, visitor):
     test_info = visitor.test_info
     trial_info = visitor.trial_info
 
+    # Cache all stage 1 sum factorization kernels used in this expression
+    SumfactCollectMapper()(expr)
+
     # Number of basis functions per direction
     leaf_element = test_info.element
     from ufl import MixedElement
diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py
index 8e745c7daf65294db5e9f34b03c7a19736918f8c..00ef29661f33d4c7f32600babe622789caca50d3 100644
--- a/python/dune/perftool/sumfact/basis.py
+++ b/python/dune/perftool/sumfact/basis.py
@@ -146,6 +146,11 @@ def pymbolic_coefficient_gradient(element, restriction, index, coeff_func, visit
     from dune.perftool.sumfact.vectorization import attach_vectorization_info
     vsf = attach_vectorization_info(sf)
 
+    # If this sum factorization kernel was not used in the dry run we
+    # just return 0
+    if vsf == 0:
+        return 0, None
+
     from dune.perftool.sumfact.realization import realize_sum_factorization_kernel
     var, insn_dep = realize_sum_factorization_kernel(vsf)
 
@@ -182,6 +187,11 @@ def pymbolic_coefficient(element, restriction, index, coeff_func, visitor_indice
     from dune.perftool.sumfact.vectorization import attach_vectorization_info
     vsf = attach_vectorization_info(sf)
 
+    # If this sum factorization kernel was not used in the dry run we
+    # just return 0
+    if vsf == 0:
+        return 0, None
+
     # Add a sum factorization kernel that implements the evaluation of
     # the basis functions at quadrature points (stage 1)
     from dune.perftool.sumfact.realization import realize_sum_factorization_kernel
diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py
index 3f1afe9c3bab2a9164d315170b3adddf3a8e25a5..5c9e4a5b50fd07d39c1d4639982d05ca82a40224 100644
--- a/python/dune/perftool/sumfact/realization.py
+++ b/python/dune/perftool/sumfact/realization.py
@@ -24,7 +24,6 @@ from dune.perftool.sumfact.permutation import (sumfact_permutation_strategy,
                                                permute_backward,
                                                permute_forward,
                                                )
-from dune.perftool.sumfact.vectorization import attach_vectorization_info
 from dune.perftool.sumfact.accumulation import sumfact_iname
 from dune.perftool.loopy.vcl import ExplicitVCLCast
 
diff --git a/python/dune/perftool/sumfact/vectorization.py b/python/dune/perftool/sumfact/vectorization.py
index 08e07782c41e9a9928e9d06854ea7ef80a4a3414..3a0447945be53aea13c3f1188a78241e883dede3 100644
--- a/python/dune/perftool/sumfact/vectorization.py
+++ b/python/dune/perftool/sumfact/vectorization.py
@@ -140,37 +140,52 @@ def decide_vectorization_strategy():
     """
     logger = logging.getLogger(__name__)
 
+    # Retrieve all sum factorization kernels for stage 1 and 3
     from dune.perftool.generation import retrieve_cache_items
-    sumfacts = [i for i in retrieve_cache_items("kernel_default and sumfactnodes")]
+    all_sumfacts = [i for i in retrieve_cache_items("kernel_default and sumfactnodes")]
+
+    # Stage 1 sumfactorizations that were actually used
+    basis_sumfacts = [i for i in retrieve_cache_items('kernel_default and basis_sf_kernels')]
+
+    # This means we can have sum factorizations that will not get used
+    inactive_sumfacts = [i for i in all_sumfacts if i.stage == 1 and i not in basis_sumfacts]
+
+    # All sum factorization kernels that get used
+    active_sumfacts = [i for i in all_sumfacts if i.stage == 3 or i in basis_sumfacts]
+
+    # We map inactive sum factorization kernels to 0
     sfdict = {}
+    for sf in inactive_sumfacts:
+        sfdict[sf] = 0
 
-    logger.debug("decide_vectorization_strategy: Found {} sum factorization nodes".format(len(sumfacts)))
+    logger.debug("decide_vectorization_strategy: Found {} active sum factorization nodes"
+                 .format(len(active_sumfacts)))
 
     if get_option("vectorize_grads"):
         # Currently we base our idea here on the fact that we only group sum
         # factorization kernels with the same input.
-        inputkeys = set(sf.input_key for sf in sumfacts)
+        inputkeys = set(sf.input_key for sf in active_sumfacts)
         for inputkey in inputkeys:
             width = get_vcl_type_size(np.float64)
-            sumfact_filter = [sf for sf in sumfacts if sf.input_key == inputkey]
+            sumfact_filter = [sf for sf in active_sumfacts if sf.input_key == inputkey]
             for old, new in horizontal_vectorization_strategy(sumfact_filter, width).items():
                 sfdict[old] = new
     elif get_option("vectorize_slice"):
-        for sumfact in sumfacts:
+        for sumfact in active_sumfacts:
             width = get_vcl_type_size(np.float64)
             for old, new in vertical_vectorization_strategy(sumfact, width).items():
                 sfdict[old] = new
     elif get_option("vectorize_diagonal"):
-        inputkeys = set(sf.input_key for sf in sumfacts)
+        inputkeys = set(sf.input_key for sf in active_sumfacts)
         for inputkey in inputkeys:
             width = get_vcl_type_size(np.float64)
-            sumfact_filter = [sf for sf in sumfacts if sf.input_key == inputkey]
+            sumfact_filter = [sf for sf in active_sumfacts if sf.input_key == inputkey]
             for old, new in diagonal_vectorization_strategy(sumfact_filter, width).items():
                 sfdict[old] = new
     else:
-        for old, new in no_vectorization(sumfacts).items():
+        for old, new in no_vectorization(active_sumfacts).items():
             sfdict[old] = new
 
     # Register the results
-    for sf in sumfacts:
+    for sf in all_sumfacts:
         _cache_vectorization_info(sf, sfdict.get(sf, no_vec(sf)))