diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py index a2e774e4e566fcdef747da4418539f9f612dc637..86697f9c0793f02624cff08bf95100e28ccff90c 100644 --- a/python/dune/perftool/sumfact/basis.py +++ b/python/dune/perftool/sumfact/basis.py @@ -103,7 +103,7 @@ def pymbolic_coefficient_gradient(element, restriction, component, coeff_func, v # with the position in the vector register. if direct_indexing_is_possible: assert len(visitor.indices) == 1 - return maybe_wrap_subscript(var, tuple(prim.Variable(i) for i in quadrature_inames()) + visitor.indices), None + return maybe_wrap_subscript(var, vsf.quadrature_index(sf, visitor.indices)), None # TODO this should be quite conditional!!! for i, buf in enumerate(buffers): diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py index 1c92755e19fe763a6582888408dfb68f4aea4f56..fb7e19dfa9369eca12ba9a0f616cb919694a0370 100644 --- a/python/dune/perftool/sumfact/symbolic.py +++ b/python/dune/perftool/sumfact/symbolic.py @@ -439,8 +439,12 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) def quadrature_shape(self): return tuple(mat.quadrature_size for mat in self.matrix_sequence) + (self.horizontal_width,) - def quadrature_index(self, sf): - return self.kernels[0].quadrature_index(sf) + (self.kernels.index(sf),) + def quadrature_index(self, sf, direct_index=None): + if direct_index is not None: + assert isinstance(direct_index, tuple) + return self.kernels[0].quadrature_index(sf) + direct_index + else: + return self.kernels[0].quadrature_index(sf) + (self.kernels.index(sf),) @property def quadrature_dimtags(self):