From 38574e66d6d1c5559e9a2164a46c83a3fd4e5ee8 Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Wed, 7 Dec 2016 09:49:53 +0100 Subject: [PATCH] Reintroduce total_index --- .../loopy/transformations/collect_rotate.py | 27 +++++++++++++++---- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/python/dune/perftool/loopy/transformations/collect_rotate.py b/python/dune/perftool/loopy/transformations/collect_rotate.py index 6e9627a0..f585c79e 100644 --- a/python/dune/perftool/loopy/transformations/collect_rotate.py +++ b/python/dune/perftool/loopy/transformations/collect_rotate.py @@ -173,11 +173,29 @@ def collect_vector_data_rotate(knl): other_insns = [i for i in knl.instructions if i.id not in [j.id for j in insns + new_insns]] # - # Add two counter variables to the kernel + # Add three counter variables to the kernel # - temporaries = knl.temporary_variables + temporaries['total_index'] = lp.TemporaryVariable('total_index', + dtype=np.int32, + scope=lp.temp_var_scope.PRIVATE, + ) + new_insns.append(lp.Assignment(prim.Variable("total_index"), # assignee + 0, # expression + within_inames=common_inames, + within_inames_is_final=True, + id="assign_total_index", + )) + new_insns.append(lp.Assignment(prim.Variable("total_index"), # assignee + prim.Sum((prim.Variable("total_index"), 1)), # expression + within_inames=common_inames.union(inames), + within_inames_is_final=True, + depends_on=frozenset(all_writers).union(frozenset({"assign_total_index"})), + depends_on_is_final=True, + id="update_total_index", + )) + # Insert a flat consecutive counter 'vec_index', which is increased after a vector chunk is handled temporaries['vec_index'] = lp.TemporaryVariable('vec_index', # name dtype=np.int32, @@ -226,8 +244,7 @@ def collect_vector_data_rotate(knl): # Determine the condition for the continue statement upper_bound = prim.Product(tuple(pw_aff_to_expr(knl.get_iname_bounds(i).size) for i in inames)) - stride = 1 if rotating else vec_size - total_check = prim.Comparison(stride * prim.Variable("vec_index") + prim.Variable("rotate_index"), "<", upper_bound) + total_check = prim.Comparison(prim.Variable("total_index"), "<", upper_bound) rotate_check = prim.Comparison(prim.Variable("rotate_index"), "!=", 0) check = prim.LogicalAnd((rotate_check, total_check)) @@ -235,7 +252,7 @@ def collect_vector_data_rotate(knl): new_insns.append(lp.CInstruction((), # iname exprs that the code needs access to "continue;", # the code predicates=frozenset({check}), - depends_on=frozenset({"update_rotate_index"}).union(frozenset(all_writers)), + depends_on=frozenset({"update_rotate_index", "update_total_index"}).union(frozenset(all_writers)), depends_on_is_final=True, within_inames=common_inames.union(inames), within_inames_is_final=True, -- GitLab