diff --git a/python/dune/perftool/loopy/transformations/collect_rotate.py b/python/dune/perftool/loopy/transformations/collect_rotate.py index 34d972a34cf1942542169c18a6f9beada59b3195..e6f1fe112401d5eebcfc5196a8cd6b0e29a57307 100644 --- a/python/dune/perftool/loopy/transformations/collect_rotate.py +++ b/python/dune/perftool/loopy/transformations/collect_rotate.py @@ -312,12 +312,12 @@ def collect_vector_data_rotate(knl): if rotating: assert isinstance(insn.assignee, prim.Subscript) - last_index = insn.assignee.index[-1] - assert last_index in tuple(range(4)) tag = get_pymbolic_tag(insn.assignee) if tag is None: print insn.assignee horizontal, vertical = tuple(int(i) for i in re.match("vecsumfac_h(.*)_v(.*)", tag).groups()) + last_index = insn.assignee.index[-1] + assert last_index in tuple(range(horizontal * vertical)) else: last_index = 0 horizontal = 1 diff --git a/python/dune/perftool/loopy/vcl.py b/python/dune/perftool/loopy/vcl.py index 1ec6a0e30efa489f0b385eec81cea7c61a62f8e7..daa15681f67b5b1311c36e9f95f5576764fb1f8c 100644 --- a/python/dune/perftool/loopy/vcl.py +++ b/python/dune/perftool/loopy/vcl.py @@ -56,6 +56,11 @@ def get_vcl_type(nptype, register_size=None, vector_width=None): return VCLTypeRegistry.types[np.dtype(nptype), vector_width] +def get_vcl_typename(nptype, register_size=None, vector_width=None): + vcltype = get_vcl_type(nptype, register_size=register_size, vector_width=vector_width) + return VCLTypeRegistry.names[vcltype] + + @function_mangler def vcl_function_mangler(knl, func, arg_dtypes): if func == "mul_add": diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py index 40547609ac26af4832e4ac936875bfea7ad88f6d..e3ad6541d5c50d3b8f63f59e0323d858dc26a332 100644 --- a/python/dune/perftool/sumfact/realization.py +++ b/python/dune/perftool/sumfact/realization.py @@ -27,6 +27,7 @@ from dune.perftool.sumfact.permutation import (sumfact_permutation_strategy, ) from dune.perftool.sumfact.vectorization import attach_vectorization_info from dune.perftool.sumfact.accumulation import sumfact_iname +from dune.perftool.loopy.vcl import get_vcl_typename import loopy as lp import numpy as np @@ -132,7 +133,7 @@ def _realize_sum_factorization_kernel(sf): out_inames = tuple(sumfact_iname(length, "out_inames_" + str(k)) for k, length in enumerate(out_shape)) vec_iname = () if matrix.vectorized: - iname = sumfact_iname(4, "vec") + iname = sumfact_iname(sf.vector_width, "vec") vec_iname = (prim.Variable(iname),) transform(lp.tag_inames, [(iname, "vec")]) @@ -159,7 +160,7 @@ def _realize_sum_factorization_kernel(sf): globalarg(direct_input, dtype=np.float64, shape=inp_shape, dim_tags=novec_ftags) if matrix.vectorized: - input_summand = prim.Call(prim.Variable("Vec4d"), + input_summand = prim.Call(prim.Variable(get_vcl_typename(np.float64, vector_width=sf.vector_width)), (prim.Subscript(prim.Variable(direct_input), input_inames),)) else: