diff --git a/python/dune/perftool/loopy/buffer.py b/python/dune/perftool/loopy/buffer.py
index ff806da7ea74dc9f2596419589c3cb69f48d242f..edb1d378f2c5ea680f18dc7c9103486f74ef2d1f 100644
--- a/python/dune/perftool/loopy/buffer.py
+++ b/python/dune/perftool/loopy/buffer.py
@@ -42,6 +42,7 @@ class FlipFlopBuffer(object):
 
 @kernel_cached
 def initialize_buffer(identifier):
+    assert isinstance(identifier, str)  # guard: identifier must be a str before it is handed to FlipFlopBuffer
     return FlipFlopBuffer(identifier)
 
 
diff --git a/python/dune/perftool/loopy/symbolic.py b/python/dune/perftool/loopy/symbolic.py
index 5cd5c502430f78640690af85fc281360867225b2..345819e5ab55d09650c4b90e92a6ad0fa39930fc 100644
--- a/python/dune/perftool/loopy/symbolic.py
+++ b/python/dune/perftool/loopy/symbolic.py
@@ -28,6 +28,9 @@ class SumfactKernel(ImmutableRecord, prim.Variable):
                  padding=frozenset(),
                  index=None,
                  insn_dep=frozenset(),
+                 coeff_func=None,
+                 element=None,
+                 component=None,
                  ):
         # Check the input and apply defaults where necessary
         assert isinstance(a_matrices, tuple)
@@ -39,8 +42,8 @@ class SumfactKernel(ImmutableRecord, prim.Variable):
         if preferred_position is not None:
             assert isinstance(preferred_position, int)
 
-        if not isinstance(restriction, tuple):
-            restriction = (restriction, 0)
+        if stage == 3:
+            assert isinstance(restriction, tuple)
 
         assert isinstance(within_inames, tuple)
 
@@ -57,6 +60,9 @@ class SumfactKernel(ImmutableRecord, prim.Variable):
                                  padding=padding,
                                  index=index,
                                  insn_dep=insn_dep,
+                                 coeff_func=coeff_func,
+                                 element=element,
+                                 component=component,
                                  )
 
         prim.Variable.__init__(self, "SUMFACT")
@@ -65,12 +71,12 @@ class SumfactKernel(ImmutableRecord, prim.Variable):
     # The methods/fields needed to get a well-formed pymbolic node
     #
     def __getinitargs__(self):
-        return (self.a_matrices, self.buffer, self.stage, self.preferred_position, self.restriction, self.within_inames, self.input, self.padding, self.index, self.insn_dep)
+        return (self.a_matrices, self.buffer, self.stage, self.preferred_position, self.restriction, self.within_inames, self.input, self.padding, self.index, self.insn_dep, self.coeff_func, self.element, self.component)  # mirrors init_arg_names below, entry for entry
 
     def stringifier(self):
         return lp.symbolic.StringifyMapper
 
-    init_arg_names = ("a_matrices", "buffer", "stage", "preferred_position", "restriction", "within_inames", "input", "padding", "index", "insn_dep")
+    init_arg_names = ("a_matrices", "buffer", "stage", "preferred_position", "restriction", "within_inames", "input", "padding", "index", "insn_dep", "coeff_func", "element", "component")
 
     mapper_method = "map_sumfact_kernel"
 
@@ -94,6 +100,15 @@ class SumfactKernel(ImmutableRecord, prim.Variable):
         """
         return hash((self.a_matrices, self.restriction, self.stage, self.buffer))
 
+    @property
+    def flat_input_shape(self):
+        """ The 'flat' input tensor shape: a 1-tuple holding the product of the column counts of all a_matrices, extended by a trailing 4 if self.vectorized is true """
+        from pytools import product
+        shape = (product(mat.cols for mat in self.a_matrices),)
+        if self.vectorized:
+            shape = shape + (4,)  # extra trailing axis of extent 4 for vectorized kernels
+        return shape
+
 
 class FusedMultiplyAdd(prim.Expression):
     """ Represents an FMA operation """
diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py
index 413f3039b5a7fdb8f43e00be246096fddffe09b8..5ed694961513990b6f8ed01a9baf90bcd9a451a9 100644
--- a/python/dune/perftool/sumfact/basis.py
+++ b/python/dune/perftool/sumfact/basis.py
@@ -24,7 +24,6 @@ from dune.perftool.sumfact.sumfact import (get_facedir,
                                            setup_theta,
                                            SumfactKernel,
                                            sumfact_iname,
-                                           sum_factorization_kernel,
                                            )
 from dune.perftool.sumfact.quadrature import quadrature_inames
 from dune.perftool.sumfact.switch import (get_facedir,
@@ -78,6 +77,9 @@ def pymbolic_coefficient_gradient(element, restriction, component, coeff_func, v
         sf = SumfactKernel(a_matrices=a_matrices,
                            restriction=restriction,
                            preferred_position=i,
+                           coeff_func=coeff_func,
+                           element=element,
+                           component=component,
                            )
 
         from dune.perftool.sumfact.vectorization import attach_vectorization_info
@@ -91,19 +93,19 @@ def pymbolic_coefficient_gradient(element, restriction, component, coeff_func, v
         index = sf.index
         padding = sf.padding
 
-        if buf is None:
-            buf = get_counted_variable("buffer")
-        if inp is None:
-            inp = get_counted_variable("input")
-
-        # Initialize the buffer for the sum fact kernel
-        shape = (product(mat.cols for mat in a_matrices),)
-        if index is not None:
-            shape = shape + (4,)
-        inp = initialize_buffer(buf).get_temporary(shape=shape,
-                                                   name=inp,
-                                                   )
-        insn_dep = frozenset({Writes(inp)})
+#        if buf is None:
+#            buf = get_counted_variable("buffer")
+#        if inp is None:
+#            inp = get_counted_variable("input")
+#
+#        # Initialize the buffer for the sum fact kernel
+#        shape = (product(mat.cols for mat in a_matrices),)
+#        if index is not None:
+#            shape = shape + (4,)
+#        inp = initialize_buffer(buf).get_temporary(shape=shape,
+#                                                   name=inp,
+#                                                   )
+#        insn_dep = frozenset({Writes(inp)})
 
         if get_option('fastdg'):
             # Name of direct input, shape and globalarg is set in sum_factorization_kernel
@@ -111,29 +113,16 @@ def pymbolic_coefficient_gradient(element, restriction, component, coeff_func, v
         else:
             direct_input = None
             # Setup the input!
-            setup_theta(inp, element, restriction, component, index, coeff_func)
+            # setup_theta(inp, element, restriction, component, index, coeff_func)
 
         # Add a sum factorization kernel that implements the
         # evaluation of the gradients of basis functions at quadrature
         # points (stage 1)
-        if not get_global_context_value("dry_run", False):
-            from dune.perftool.sumfact.realization import realize_sum_factorization_kernel
-            var, insn_dep = realize_sum_factorization_kernel(sf,
-                                                 insn_dep=insn_dep,
+        from dune.perftool.sumfact.realization import realize_sum_factorization_kernel
+        var, insn_dep = realize_sum_factorization_kernel(sf,
                                                  outshape=tuple(mat.rows for mat in a_matrices if mat.face is None),
                                                  direct_input=direct_input,
                                                  )
-#            var, insn_dep = sum_factorization_kernel(a_matrices,
-#                                                 buf,
-#                                                 1,
-#                                                 preferred_position=i,
-#                                                 insn_dep=insn_dep,
-#                                                 restriction=restriction,
-#                                                 outshape=tuple(mat.rows for mat in a_matrices if mat.face is None),
-#                                                 direct_input=direct_input,
-#                                                 )
-        else:
-            var = sf
 
         buffers.append(var)
 
@@ -171,6 +160,9 @@ def pymbolic_coefficient(element, restriction, component, coeff_func, visitor):
 
     sf = SumfactKernel(a_matrices=a_matrices,
                        restriction=restriction,
+                       coeff_func=coeff_func,
+                       element=element,
+                       component=component,
                        )
 
     # TODO: Move this away!
@@ -184,19 +176,19 @@ def pymbolic_coefficient(element, restriction, component, coeff_func, visitor):
     inp = sf.input
     index = sf.index
     padding = sf.padding
-
-    if buf is None:
-        buf = get_counted_variable("buffer")
-    if inp is None:
-        inp = get_counted_variable("input")
-
-    # Flip flop buffers for sumfactorization
-    shape = (product(mat.cols for mat in a_matrices),)
-    if index is not None:
-        shape = shape + (4,)
-    initialize_buffer(buf).get_temporary(shape=shape,
-                                      name=inp,
-                                      )
+#
+#    if buf is None:
+#        buf = get_counted_variable("buffer")
+#    if inp is None:
+#        inp = get_counted_variable("input")
+#
+#    # Flip flop buffers for sumfactorization
+#    shape = (product(mat.cols for mat in a_matrices),)
+#    if index is not None:
+#        shape = shape + (4,)
+#    initialize_buffer(buf).get_temporary(shape=shape,
+#                                      name=inp,
+#                                      )
 
     if get_option('fastdg'):
         # Name of direct input, shape and globalarg is set in sum_factorization_kernel
@@ -204,14 +196,14 @@ def pymbolic_coefficient(element, restriction, component, coeff_func, visitor):
     else:
         direct_input = None
         # Setup the input!
-        setup_theta(inp, element, restriction, component, index, coeff_func)
+        # setup_theta(inp, element, restriction, component, index, coeff_func)
 
     # Add a sum factorization kernel that implements the evaluation of
     # the basis functions at quadrature points (stage 1)
     if not get_global_context_value("dry_run", False):
         from dune.perftool.sumfact.realization import realize_sum_factorization_kernel
         var, _ = realize_sum_factorization_kernel(sf,
-                                                  insn_dep=frozenset({Writes(inp)}),
+#                                                  insn_dep=frozenset({Writes(inp)}),
                                                   outshape=tuple(mat.rows for mat in a_matrices if mat.face is None),
                                                   direct_input=direct_input,
                                                   )
diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py
index 0fb0638a8fcd651205fc1833b11bac5a54f61b54..9032ad1b0b642dbb3182dd8c4938ec9008b8683f 100644
--- a/python/dune/perftool/sumfact/realization.py
+++ b/python/dune/perftool/sumfact/realization.py
@@ -16,7 +16,9 @@ from dune.perftool.generation import (barrier,
 from dune.perftool.loopy.buffer import (get_buffer_temporary,
                                         switch_base_storage,
                                         )
+from dune.perftool.pdelab.argument import pymbolic_coefficient
 from dune.perftool.pdelab.geometry import world_dimension
+from dune.perftool.pdelab.spaces import name_lfs, name_lfs_bound
 from dune.perftool.options import get_option
 from dune.perftool.pdelab.signatures import assembler_routine_name
 from dune.perftool.sumfact.permutation import (_sf_permutation_strategy,
@@ -40,20 +42,40 @@ def realize_sum_factorization_kernel(sf, insn_dep=frozenset(), outshape=None, di
         insn_dep = frozenset({insn_dep})
     assert isinstance(insn_dep, frozenset)
 
+    # Get the vectorization info. During dry run, this is a no-op
+#    sf = attach_vectorization_info(sf)
     if get_global_context_value("dry_run", False):
         # During the dry run, we just return the kernel as passed into this
         # function. After the dry run, it can be used to attach information
         # about vectorization.
         return sf, insn_dep
-#    else:
-#        # This is the second run: Retrieve the vectorization information
-#        # attached in dune.perftool.sumfact.vectorization
-#        sf = attach_vectorization_info(sf)
 
     # Get the instruction dependencies of the sumfact kernel. This variable will be
     # updated throughout this function.
     insn_dep = insn_dep.union(sf.insn_dep)
 
+    # Define some helper constructs that make it easier to write generic code later
+    vecindex = () if sf.index is None else (sf.index,)
+
+    # Set up the input for stage 1
+    if sf.stage == 1 and not get_option("fastdg"):
+        assert sf.coeff_func
+
+        # Get the input temporary!
+        input_setup = get_buffer_temporary(sf.buffer,
+                                           shape=sf.flat_input_shape,
+                                           )
+
+        # Write initial coefficients into buffer
+        lfs = name_lfs(sf.element, sf.restriction, sf.component)
+        basisiname = sumfact_iname(name_lfs_bound(lfs), "basis")
+        container = sf.coeff_func(sf.restriction)
+        coeff = pymbolic_coefficient(container, lfs, basisiname)
+        assignee = prim.Subscript(prim.Variable(input_setup), (prim.Variable(basisiname),) + vecindex)
+        insn_dep = instruction(assignee=assignee,
+                               expression=coeff,
+                               )
+
     # Prepare some dim_tags/shapes for later use
     ftags = ",".join(["f"] * sf.length)
     novec_ftags = ftags
diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py
index 7fa0870786cad38a70e126172c51d7c8018a25e9..3bf4668c143275b8f53318e345dbffe96a1908b6 100644
--- a/python/dune/perftool/sumfact/sumfact.py
+++ b/python/dune/perftool/sumfact/sumfact.py
@@ -184,7 +184,6 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
             index = ()
             vectag = frozenset()
 
-        base_storage_size = product(max(mat.rows, mat.cols) for mat in a_matrices)
         temp = initialize_buffer(buf).get_temporary(shape=shape,
                                                  dim_tags=dim_tags,
                                                  name=inp,
@@ -356,19 +355,6 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
             insn_dep = emit_sumfact_kernel(None, restriction, insn_dep)
 
 
-@generator_factory(item_tags=("sumfactkernel",), context_tags=("kernel",), cache_key_generator=lambda a, b, s, **kw: (a, b, s, kw.get("restriction", 0)))
-def sum_factorization_kernel(a_matrices,
-                             buf,
-                             stage,
-                             insn_dep=frozenset({}),
-                             additional_inames=frozenset({}),
-                             preferred_position=None,
-                             outshape=None,
-                             restriction=0,
-                             direct_input=None,
-                             direct_output=None,
-                             visitor=None,
-                             ):
     """Create a sum factorization kernel
 
     Sum factorization can be written as
@@ -430,214 +416,4 @@ def sum_factorization_kernel(a_matrices,
     restriction: Restriction for faces values.
     direct_input: Global data structure containing input for
         sumfactorization (e.g. when using FastDGGridOperator).
-    """
-    # Return a pymbolic SumfactKernel node in the dry run. This will
-    # be used to decide on an appropriate vectorization strategy
-    # before we do the real thing.
-    if get_global_context_value("dry_run", False):
-        return SumfactKernel(a_matrices, buf, stage, preferred_position, restriction), frozenset()
-
-    ftags = ",".join(["f"] * len(a_matrices))
-    novec_ftags = ftags
-    ctags = ",".join(["c"] * len(a_matrices))
-    vec_shape = ()
-    if next(iter(a_matrices)).vectorized:
-        ftags = ftags + ",vec"
-        ctags = ctags + ",vec"
-        vec_shape = (4,)
-
-    # Measure times and count operations in c++ code
-    if get_option("instrumentation_level") >= 4:
-        timer_name = assembler_routine_name() + '_kernel' + '_stage{}'.format(stage)
-        post_include('HP_DECLARE_TIMER({});'.format(timer_name), filetag='operatorfile')
-        dump_accumulate_timer(timer_name)
-        insn_dep = frozenset({instruction(code="HP_TIMER_START({});".format(timer_name),
-                                          depends_on=insn_dep,
-                                          within_inames=additional_inames)})
-
-    # Put a barrier before the sumfactorization kernel
-    insn_dep = frozenset({barrier(depends_on=insn_dep,
-                                  within_inames=additional_inames,
-                                  )})
-
-    # Decide in which order we want to process directions in the
-    # sumfactorization. A clever ordering can lead to a reduced
-    # complexity. This will e.g. happen at faces where we only have
-    # one quadratue point m_l=1 if l is the normal direction of the
-    # face.
-    #
-    # Rule of thumb: small m's early and large n's late.
-    perm = _sf_permutation_strategy(a_matrices, stage)
-
-    # Permute a_matrices
-    a_matrices = _permute_forward(a_matrices, perm)
-
-    # Product of all matrices
-    for l, a_matrix in enumerate(a_matrices):
-        # Compute the correct shapes of in- and output matrices of this matrix-matrix multiplication
-        # and get inames that realize the product.
-        inp_shape = (a_matrix.cols,) + tuple(mat.cols for mat in a_matrices[l + 1:]) + tuple(mat.rows for mat in a_matrices[:l])
-        out_shape = (a_matrix.rows,) + tuple(mat.cols for mat in a_matrices[l + 1:]) + tuple(mat.rows for mat in a_matrices[:l])
-        out_inames = tuple(sumfact_iname(length, "out_inames_" + str(k)) for k, length in enumerate(out_shape))
-        vec_iname = ()
-        if a_matrix.vectorized:
-            iname = sumfact_iname(4, "vec")
-            vec_iname = (prim.Variable(iname),)
-            transform(lp.tag_inames, [(iname, "vec")])
-
-        # A trivial reduction is implemented as a product, otherwise we run into
-        # a code generation corner case producing way too complicated code. This
-        # could be fixed upstream, but the loopy code realizing reductions is not
-        # trivial and the priority is kind of low.
-        if a_matrix.cols != 1:
-            k = sumfact_iname(a_matrix.cols, "red")
-            k_expr = prim.Variable(k)
-        else:
-            k_expr = 0
-
-        # Setup the input of the sum factorization kernel. In the
-        # first matrix multiplication this can be taken from
-        # * an input temporary (default)
-        # * a global data structure (if FastDGGridOperator is in use)
-        # * a value from a global data structure, broadcasted to a vector type (vectorized + FastDGGridOperator)
-        input_inames = (k_expr,) + tuple(prim.Variable(j) for j in out_inames[1:])
-        if l == 0 and direct_input is not None:
-            # See comment below
-            input_inames = _permute_backward(input_inames, perm)
-            inp_shape = _permute_backward(inp_shape, perm)
-
-            globalarg(direct_input, dtype=np.float64, shape=inp_shape, dim_tags=novec_ftags)
-            if a_matrix.vectorized:
-                input_summand = prim.Call(prim.Variable("Vec4d"),
-                                          (prim.Subscript(prim.Variable(direct_input),
-                                                          input_inames),))
-            else:
-                input_summand = prim.Subscript(prim.Variable(direct_input),
-                                               input_inames + vec_iname)
-        else:
-            # If we did permute the order of a matrices above we also
-            # permuted the order of out_inames. Unfortunately the
-            # order of our input is from 0 to d-1. This means we need
-            # to permute _back_ to get the right coefficients.
-            if l == 0:
-                inp_shape = _permute_backward(inp_shape, perm)
-                input_inames = _permute_backward(input_inames, perm)
-
-            # Get a temporary that interprets the base storage of the input
-            # as a column-major matrix. In later iteration of the amatrix loop
-            # this reinterprets the output of the previous iteration.
-            inp = get_buffer_temporary(buf,
-                                       shape=inp_shape + vec_shape,
-                                       dim_tags=ftags)
-
-            # The input temporary will only be read from, so we need to silence the loopy warning
-            silenced_warning('read_no_write({})'.format(inp))
-
-            input_summand = prim.Subscript(prim.Variable(inp),
-                                           input_inames + vec_iname)
-
-        switch_base_storage(buf)
-
-        # Get a temporary that interprets the base storage of the output.
-        #
-        # Note: In this step the reordering of the fastest directions
-        # is happening. The new direction (out_inames[0]) and the
-        # corresponding shape (out_shape[0]) goes to the end (slowest
-        # direction) and everything stays column major (ftags->fortran
-        # style).
-        #
-        # If we are in the last step we reverse the permutation.
-        output_shape = tuple(out_shape[1:]) + (out_shape[0],)
-        if l == len(a_matrices) - 1:
-            output_shape = _permute_backward(output_shape, perm)
-        out = get_buffer_temporary(buf,
-                                   shape=output_shape + vec_shape,
-                                   dim_tags=ftags)
-
-        # Write the matrix-matrix multiplication expression
-        matprod = Product((prim.Subscript(prim.Variable(a_matrix.name),
-                                          (prim.Variable(out_inames[0]), k_expr) + vec_iname),
-                           input_summand))
-
-        # ... which may be a reduction, if k>0
-        if a_matrix.cols != 1:
-            matprod = lp.Reduction("sum", k, matprod)
-
-        # Here we also move the new direction (out_inames[0]) to the
-        # end and reverse permutation
-        output_inames = tuple(prim.Variable(i) for i in out_inames[1:]) + (prim.Variable(out_inames[0]),)
-        if l == len(a_matrices) - 1:
-            output_inames = _permute_backward(output_inames, perm)
-
-        # In case of direct output we directly accumulate the result
-        # of the Sumfactorization into some global data structure.
-        if l == len(a_matrices) - 1 and direct_output is not None:
-            ft = get_global_context_value("form_type")
-            if ft == 'residual' or ft == 'jacobian_apply':
-                globalarg(direct_output, dtype=np.float64, shape=output_shape, dim_tags=novec_ftags)
-                assignee = prim.Subscript(prim.Variable(direct_output), output_inames)
-            else:
-                assert ft == 'jacobian'
-                globalarg(direct_output,
-                          dtype=np.float64,
-                          shape=output_shape + output_shape,
-                          dim_tags=novec_ftags + "," + novec_ftags)
-                # TODO the next line should get its inames from
-                # elsewhere. This is *NOT* robust (but works right
-                # now)
-                _ansatz_inames = tuple(Variable(visitor.inames[i]) for i in range(world_dimension()))
-                assignee = prim.Subscript(prim.Variable(direct_output), _ansatz_inames + output_inames)
-
-            # In case of vectorization we need to apply a horizontal add
-            if a_matrix.vectorized:
-                matprod = prim.Call(prim.Variable("horizontal_add"),
-                                    (matprod,))
-
-            # We need to accumulate
-            matprod = prim.Sum((assignee, matprod))
-        else:
-            assignee = prim.Subscript(prim.Variable(out), output_inames + vec_iname)
-
-        # Issue the reduction instruction that implements the multiplication
-        # at the same time store the instruction ID for the next instruction to depend on
-        insn_dep = frozenset({instruction(assignee=assignee,
-                                          expression=matprod,
-                                          forced_iname_deps=frozenset([iname for iname in out_inames]).union(additional_inames),
-                                          forced_iname_deps_is_final=True,
-                                          depends_on=insn_dep,
-                                          )
-                              })
-
-    # Measure times and count operations in c++ code
-    if get_option("instrumentation_level") >= 4:
-        insn_dep = frozenset({instruction(code="HP_TIMER_STOP({});".format(timer_name),
-                                          depends_on=insn_dep,
-                                          within_inames=additional_inames)})
-        if stage == 1:
-            qp_timer_name = assembler_routine_name() + '_kernel' + '_quadratureloop'
-            post_include('HP_DECLARE_TIMER({});'.format(timer_name), filetag='operatorfile')
-            dump_accumulate_timer(timer_name)
-            insn_dep = instruction(code="HP_TIMER_START({});".format(qp_timer_name),
-                                   depends_on=insn_dep)
-
-    if outshape is None:
-        assert stage == 3
-        outshape = tuple(mat.rows for mat in a_matrices)
-
-    dim_tags = ",".join(['f'] * len(outshape))
-
-    if next(iter(a_matrices)).vectorized:
-        outshape = outshape + vec_shape
-        # This is a 'bit' hacky: In stage 3 we need to return something with vectag, in stage 1 not.
-        if stage == 1:
-            dim_tags = dim_tags + ",c"
-        else:
-            dim_tags = dim_tags + ",vec"
-
-    out = get_buffer_temporary(buf,
-                               shape=outshape,
-                               dim_tags=dim_tags,
-                               )
-    silenced_warning('read_no_write({})'.format(out))
-
-    return next(iter(a_matrices)).output_to_pymbolic(out), insn_dep
+    """
\ No newline at end of file
diff --git a/python/dune/perftool/sumfact/vectorization.py b/python/dune/perftool/sumfact/vectorization.py
index 8e9b0d076cb945c466e9344fe07fe44a5c4815f4..2aa07d97e17ecc82a7c260a47b1e7415f7115a1d 100644
--- a/python/dune/perftool/sumfact/vectorization.py
+++ b/python/dune/perftool/sumfact/vectorization.py
@@ -21,10 +21,12 @@ def _cache_vectorization_info(old, new):
     return new
 
 
+_collect_sumfact_nodes = generator_factory(item_tags=("sumfactnodes", "dryrundata"), no_deco=True)  # generator tagged "sumfactnodes"/"dryrundata"; attach_vectorization_info routes kernels through it during the dry run
+
 def attach_vectorization_info(sf):
     assert isinstance(sf, SumfactKernel)
     if get_global_context_value("dry_run"):
-        return sf
+        return _collect_sumfact_nodes(sf)  # dry run: hand the kernel to the module-level _collect_sumfact_nodes generator and return its result
     else:
         return _cache_vectorization_info(sf, None)
 
@@ -110,11 +112,15 @@ def decide_vectorization_strategy():
     if not get_option("vectorize_grads"):
         no_vectorization(sumfacts)
     else:
-        for stage in (1, 3):
-            res = (Restriction.NONE, Restriction.POSITIVE, Restriction.NEGATIVE)
-            import itertools as it
-            for restriction in it.product(res, res):
-                decide_stage_vectorization_strategy(sumfacts, stage, restriction)
+        res = (Restriction.NONE, Restriction.POSITIVE, Restriction.NEGATIVE)
+        # Stage 1 kernels
+        for restriction in res:
+            decide_stage_vectorization_strategy(sumfacts, 1, restriction)
+
+        # Stage 3 kernels
+        import itertools as it
+        for restriction in it.product(res, res):
+            decide_stage_vectorization_strategy(sumfacts, 3, restriction)
 
 
 class HasSumfactMapper(lp.symbolic.CombineMapper):
@@ -133,6 +139,9 @@ class HasSumfactMapper(lp.symbolic.CombineMapper):
     def map_sumfact_kernel(self, expr):
         return frozenset({expr})
 
+    def map_tagged_variable(self, expr):
+        return frozenset()  # a tagged variable contributes no SumfactKernel to the collected set
+
 
 def find_sumfact(expr):
     return HasSumfactMapper()(expr)