Skip to content
Snippets Groups Projects
Commit 38574e66 authored by Dominic Kempf
Browse files

Reintroduce total_index

parent fdc0f05b
No related branches found
No related tags found
No related merge requests found
......@@ -173,11 +173,29 @@ def collect_vector_data_rotate(knl):
other_insns = [i for i in knl.instructions if i.id not in [j.id for j in insns + new_insns]]
#
# Add two counter variables to the kernel
# Add three counter variables to the kernel
#
temporaries = knl.temporary_variables
temporaries['total_index'] = lp.TemporaryVariable('total_index',
dtype=np.int32,
scope=lp.temp_var_scope.PRIVATE,
)
new_insns.append(lp.Assignment(prim.Variable("total_index"), # assignee
0, # expression
within_inames=common_inames,
within_inames_is_final=True,
id="assign_total_index",
))
new_insns.append(lp.Assignment(prim.Variable("total_index"), # assignee
prim.Sum((prim.Variable("total_index"), 1)), # expression
within_inames=common_inames.union(inames),
within_inames_is_final=True,
depends_on=frozenset(all_writers).union(frozenset({"assign_total_index"})),
depends_on_is_final=True,
id="update_total_index",
))
# Insert a flat consecutive counter 'vec_index', which is increased after a vector chunk is handled
temporaries['vec_index'] = lp.TemporaryVariable('vec_index', # name
dtype=np.int32,
......@@ -226,8 +244,7 @@ def collect_vector_data_rotate(knl):
# Determine the condition for the continue statement
upper_bound = prim.Product(tuple(pw_aff_to_expr(knl.get_iname_bounds(i).size) for i in inames))
stride = 1 if rotating else vec_size
total_check = prim.Comparison(stride * prim.Variable("vec_index") + prim.Variable("rotate_index"), "<", upper_bound)
total_check = prim.Comparison(prim.Variable("total_index"), "<", upper_bound)
rotate_check = prim.Comparison(prim.Variable("rotate_index"), "!=", 0)
check = prim.LogicalAnd((rotate_check, total_check))
......@@ -235,7 +252,7 @@ def collect_vector_data_rotate(knl):
new_insns.append(lp.CInstruction((), # iname exprs that the code needs access to
"continue;", # the code
predicates=frozenset({check}),
depends_on=frozenset({"update_rotate_index"}).union(frozenset(all_writers)),
depends_on=frozenset({"update_rotate_index", "update_total_index"}).union(frozenset(all_writers)),
depends_on_is_final=True,
within_inames=common_inames.union(inames),
within_inames_is_final=True,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment