diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py index 96ef00127c5848c00866e99a17018be338fe2bed..abefc1859f008deff090b3178dd3a1da1460cac1 100644 --- a/python/dune/perftool/sumfact/realization.py +++ b/python/dune/perftool/sumfact/realization.py @@ -228,12 +228,16 @@ def _realize_sum_factorization_kernel(sf): direct_output = "{}x{}".format(direct_output, sf.trial_element_index) rowsize = sum(tuple(s for s in _local_sizes(sf.trial_element))) + element = sf.trial_element + if element.num_sub_elements() > 0: + element = element.extract_component(sf.trial_element_index)[1] + other_shape = tuple(element.degree() + 1 for e in range(sf.length)) from pytools import product manual_strides = tuple("stride:{}".format(rowsize * product(output_shape[:i])) for i in range(sf.length)) dim_tags = "{},{}".format(novec_ftags, ",".join(manual_strides)) globalarg(direct_output, dtype=np.float64, - shape=output_shape + output_shape, + shape=other_shape + output_shape, offset=rowsize * _dof_offset(sf.test_element, sf.test_element_index) + _dof_offset(sf.trial_element, sf.trial_element_index), dim_tags=dim_tags, )