diff --git a/python/dune/perftool/sumfact/accumulation.py b/python/dune/perftool/sumfact/accumulation.py index 4d72961372b7bde1af3cdd49855c2d190d3b4b13..20ab84ac7bff6461a5dbcf204842d674bcc60a7b 100644 --- a/python/dune/perftool/sumfact/accumulation.py +++ b/python/dune/perftool/sumfact/accumulation.py @@ -202,7 +202,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): vecinames = () # TODO: evaluate whether the following line would be okay with vsf.vectorized if vsf.vec_index(sf) is not None: - iname = accum_iname((accterm.argument.restriction, restriction), vsf.horizontal_width, "vec") + iname = accum_iname((accterm.argument.restriction, restriction), vsf.vector_width, "vec") vecinames = (iname,) transform(lp.tag_inames, [(iname, "vec")]) from dune.perftool.tools import maybe_wrap_subscript diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py index 1debd5ac9946ef29bb914c2421417672cafc4a5b..6487ce4999daed44ffd5d0617e92806c7079b791 100644 --- a/python/dune/perftool/sumfact/symbolic.py +++ b/python/dune/perftool/sumfact/symbolic.py @@ -512,7 +512,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) @property def dof_shape(self): - return tuple(mat.basis_size for mat in self.matrix_sequence) + (self.horizontal_width,) + return tuple(mat.basis_size for mat in self.matrix_sequence) + (self.vector_width,) @property def dof_dimtags(self): diff --git a/python/dune/perftool/sumfact/vectorization.py b/python/dune/perftool/sumfact/vectorization.py index 229d070c670b17b708c926fa15232246cbffa733..2c6f42250809b6883679024fe5c1aa2d04cb22a7 100644 --- a/python/dune/perftool/sumfact/vectorization.py +++ b/python/dune/perftool/sumfact/vectorization.py @@ -42,41 +42,36 @@ def no_vectorization(sumfacts): def vertical_vectorization_strategy(sumfact, depth): - # For sake of simplicity we restrict us to stage 1 so far - if sumfact.stage == 1: - # Assert that this is not already sliced - assert all(mat.slice_size is None for mat in sumfact.matrix_sequence) - - # Determine which of the matrices in the kernel should be sliced - def determine_slice_direction(): - for i, mat in enumerate(sumfact.matrix_sequence): - if mat.quadrature_size % depth == 0: - return i - elif mat.quadrature_size != 1: - raise PerftoolError("Vertical vectorization is not possible!") - - sliced = determine_slice_direction() - - kernels = [] - oldtab = sumfact.matrix_sequence[sliced] - for i in range(depth): - seq = list(sumfact.matrix_sequence) - seq[sliced] = oldtab.copy(slice_size=depth, - slice_index=i) - kernels.append(sumfact.copy(matrix_sequence=tuple(seq))) - - buffer = get_counted_variable("vertical_buffer") - input = get_counted_variable("vertical_input") - - vsf = VectorizedSumfactKernel(kernels=tuple(kernels), - buffer=buffer, - input=input, - vertical_width=depth, - ) - return _cache_vectorization_info(sumfact, vsf) - else: - return _cache_vectorization_info(sumfact, sumfact.copy(buffer=get_counted_variable("buffer"), - input=get_counted_variable("input"))) + # Assert that this is not already sliced + assert all(mat.slice_size is None for mat in sumfact.matrix_sequence) + + # Determine which of the matrices in the kernel should be sliced + def determine_slice_direction(): + for i, mat in enumerate(sumfact.matrix_sequence): + if mat.quadrature_size % depth == 0: + return i + elif mat.quadrature_size != 1: + raise PerftoolError("Vertical vectorization is not possible!") + + sliced = determine_slice_direction() + + kernels = [] + oldtab = sumfact.matrix_sequence[sliced] + for i in range(depth): + seq = list(sumfact.matrix_sequence) + seq[sliced] = oldtab.copy(slice_size=depth, + slice_index=i) + kernels.append(sumfact.copy(matrix_sequence=tuple(seq))) + + buffer = get_counted_variable("vertical_buffer") + input = get_counted_variable("vertical_input") + + vsf = VectorizedSumfactKernel(kernels=tuple(kernels), + buffer=buffer, + input=input, + vertical_width=depth, + ) + return _cache_vectorization_info(sumfact, vsf) def horizontal_vectorization_strategy(sumfacts):