diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py
index 96ef00127c5848c00866e99a17018be338fe2bed..abefc1859f008deff090b3178dd3a1da1460cac1 100644
--- a/python/dune/perftool/sumfact/realization.py
+++ b/python/dune/perftool/sumfact/realization.py
@@ -228,12 +228,16 @@ def _realize_sum_factorization_kernel(sf):
 
                 direct_output = "{}x{}".format(direct_output, sf.trial_element_index)
                 rowsize = sum(tuple(s for s in _local_sizes(sf.trial_element)))
+                element = sf.trial_element
+                if element.num_sub_elements() > 0:
+                    element = element.extract_component(sf.trial_element_index)[1]
+                other_shape = tuple(element.degree() + 1 for e in range(sf.length))
                 from pytools import product
                 manual_strides = tuple("stride:{}".format(rowsize * product(output_shape[:i])) for i in range(sf.length))
                 dim_tags = "{},{}".format(novec_ftags, ",".join(manual_strides))
                 globalarg(direct_output,
                           dtype=np.float64,
-                          shape=output_shape + output_shape,
+                          shape=other_shape + output_shape,
                           offset=rowsize * _dof_offset(sf.test_element, sf.test_element_index) + _dof_offset(sf.trial_element, sf.trial_element_index),
                           dim_tags=dim_tags,
                           )