diff --git a/python/dune/perftool/loopy/transformations/collect_rotate.py b/python/dune/perftool/loopy/transformations/collect_rotate.py index 5839873dceb5222ffc65edb02246b70243762973..d3d35c5facb992a4ebbde963eb94b0db84cac4d3 100644 --- a/python/dune/perftool/loopy/transformations/collect_rotate.py +++ b/python/dune/perftool/loopy/transformations/collect_rotate.py @@ -10,6 +10,7 @@ from dune.perftool.tools import get_pymbolic_basename from loopy.kernel.creation import parse_domains from loopy.symbolic import pw_aff_to_expr +from loopy.match import Tagged from pymbolic.mapper.dependency import DependencyMapper from pymbolic.mapper.substitutor import substitute @@ -133,7 +134,7 @@ def collect_vector_data_rotate(knl, insns, inames): # Add a vector view to this quantity knl = add_vector_view(knl, quantity) replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)), - (prim.Sum((prim.FloorDiv(prim.Variable("total_index"), vec_size), -1)), prim.Variable(new_iname)), + (prim.Variable("vec_index"), prim.Variable(new_iname)), ) other_insns = [i for i in knl.instructions if i.id not in [j.id for j in insns + new_insns]] @@ -144,24 +145,24 @@ def collect_vector_data_rotate(knl, insns, inames): temporaries = knl.temporary_variables - # Insert a flat consecutive counter 'total_index' - temporaries['total_index'] = lp.TemporaryVariable('total_index', # name - dtype=np.int32, - scope=lp.temp_var_scope.PRIVATE, - ) - new_insns.append(lp.Assignment(prim.Variable("total_index"), # assignee + # Insert a flat consecutive counter 'vec_index', which is increased after a vector chunk is handled + temporaries['vec_index'] = lp.TemporaryVariable('vec_index', # name + dtype=np.int32, + scope=lp.temp_var_scope.PRIVATE, + ) + new_insns.append(lp.Assignment(prim.Variable("vec_index"), # assignee 0, # expression within_inames=common_inames, within_inames_is_final=True, - id="assign_total_index", + id="assign_vec_index", )) - new_insns.append(lp.Assignment(prim.Variable("total_index"), # assignee - prim.Sum((prim.Variable("total_index"), 1)), # expression + new_insns.append(lp.Assignment(prim.Variable("vec_index"), # assignee + prim.Sum((prim.Variable("vec_index"), 1)), # expression within_inames=common_inames.union(inames), within_inames_is_final=True, - depends_on=frozenset(all_writers).union(frozenset({"assign_total_index"})), + depends_on=frozenset({Tagged("vec_write"), "assign_vec_index"}), depends_on_is_final=True, - id="update_total_index", + id="update_vec_index", )) # Insert a rotating index, that counts 0 , .. , vecsize - 1 @@ -192,7 +193,7 @@ def collect_vector_data_rotate(knl, insns, inames): # Determine the condition for the continue statement upper_bound = prim.Product(tuple(pw_aff_to_expr(knl.get_iname_bounds(i).size) for i in inames)) - total_check = prim.Comparison(prim.Variable("total_index"), "<", upper_bound) + total_check = prim.Comparison(vec_size * prim.Variable("vec_index") + prim.Variable("rotate_index"), "<", upper_bound) rotate_check = prim.Comparison(prim.Variable("rotate_index"), "!=", 0) check = prim.LogicalAnd((rotate_check, total_check)) @@ -200,7 +201,7 @@ def collect_vector_data_rotate(knl, insns, inames): new_insns.append(lp.CInstruction((), # iname exprs that the code needs access to "continue;", # the code predicates=frozenset({check}), - depends_on=frozenset({"update_rotate_index", "update_total_index"}).union(frozenset(all_writers)), + depends_on=frozenset({"update_rotate_index"}).union(frozenset(all_writers)), depends_on_is_final=True, within_inames=common_inames.union(inames), within_inames_is_final=True, @@ -218,7 +219,7 @@ def collect_vector_data_rotate(knl, insns, inames): lhsname = get_vector_view_name(lhsname) new_insns.append(lp.Assignment(prim.Subscript(prim.Variable(lhsname), - (prim.Sum((prim.FloorDiv(prim.Variable("total_index"), vec_size), -1)), prim.Variable(new_iname)), + (prim.Variable("vec_index"), prim.Variable(new_iname)), ), substitute(insn.expression, replacemap_vec), depends_on=frozenset({"continue_stmt"}), @@ -226,6 +227,7 @@ def collect_vector_data_rotate(knl, insns, inames): within_inames=common_inames.union(frozenset(inames + (new_iname,))), within_inames_is_final=True, id=insn.id, + tags=frozenset({"vec_write"}) ) )