diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py index c063c7e17a963b9daf2bb8a85d7a74b148c88e65..dcc06af6b0ee4121934fc4ecc312493342abc1da 100644 --- a/python/dune/perftool/sumfact/realization.py +++ b/python/dune/perftool/sumfact/realization.py @@ -272,4 +272,4 @@ def _realize_sum_factorization_kernel(sf): ) silenced_warning('read_no_write({})'.format(out)) - return next(iter(matrix_sequence)).output_to_pymbolic(out), insn_dep + return sf.output_to_pymbolic(out), insn_dep diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py index 42cf33825c45ff8463da7ed0daa00b619fbc7b2f..0017d467ef1c2126398edf6ed2eb7f9f03def700 100644 --- a/python/dune/perftool/sumfact/symbolic.py +++ b/python/dune/perftool/sumfact/symbolic.py @@ -3,6 +3,7 @@ from pytools import ImmutableRecord import pymbolic.primitives as prim +import loopy as lp class SumfactKernel(ImmutableRecord, prim.Variable): @@ -243,3 +244,9 @@ class SumfactKernel(ImmutableRecord, prim.Variable): return self.quadrature_dimtags else: return self.dof_dimtags + + def output_to_pymbolic(self, name): + if self.vectorized: + return lp.TaggedVariable(name, "vector") + else: + return lp.TaggedVariable(name, "sumfac") diff --git a/python/dune/perftool/sumfact/tabulation.py b/python/dune/perftool/sumfact/tabulation.py index 03eb09154f0bbb0fc00e9f4dfec5ed11b612d61c..c6fa99e3b28a1a9f1a499d6559461a5ec1af7cc2 100644 --- a/python/dune/perftool/sumfact/tabulation.py +++ b/python/dune/perftool/sumfact/tabulation.py @@ -57,9 +57,6 @@ class BasisTabulationMatrix(BasisTabulationMatrixBase): def vectorized(self): return False - def output_to_pymbolic(self, name): - return lp.TaggedVariable(name, "sumfac") - class BasisTabulationMatrixArray(BasisTabulationMatrixBase): def __init__(self, rows, cols, transpose, derivative, face): @@ -95,9 +92,6 @@ class BasisTabulationMatrixArray(BasisTabulationMatrixBase): def vectorized(self): return True - def output_to_pymbolic(self, name): - return lp.TaggedVariable(name, "vector") - def quadrature_points_per_direction(): # Quadrature order