Skip to content
Snippets Groups Projects
Commit 718ed05c authored by Dominic Kempf
Browse files

Toy example works

parent 0c2f06d3
No related branches found
No related tags found
No related merge requests found
......@@ -46,6 +46,13 @@ def collect_vector_data_rotate(knl, insns, inames):
# TODO infer the numpy type here
vec_size = get_vcl_type_size(np.float64)
# Add an iname to the kernel which will be used for vectorization
new_iname = "quad_vec_{}".format("_".join(inames))
domain = "{{ [{0}] : 0<={0}<{1} }}".format(new_iname, str(vec_size))
domain = parse_domains(domain, {})
knl = knl.copy(domains=knl.domains + domain)
knl = lp.tag_inames(knl, [(new_iname, "vec")])
#
# Inspect the given instructions for dependent quantities
#
......@@ -73,7 +80,7 @@ def collect_vector_data_rotate(knl, insns, inames):
)
replacemap_arr[quantity] = prim.Subscript(prim.Variable(arrname), (prim.Variable('rotate_index'),))
replacemap_vec[expr] = prim.Variable(get_vector_view_name(arrname))
replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(arrname)), (0, prim.Variable(new_iname),))
write_match = lp.match.Or(tuple(lp.match.Writes(q) for q in quantities))
iname_match = lp.match.And(tuple(lp.match.Iname(i) for i in inames))
......@@ -136,49 +143,38 @@ def collect_vector_data_rotate(knl, insns, inames):
id="update_rotate_index",
))
# Determine the condition for the continue statement
upper_bound = prim.Product(tuple(pw_aff_to_expr(knl.get_iname_bounds(i).size) for i in inames))
total_check = prim.Comparison(prim.Variable("total_index"), "<", upper_bound)
rotate_check = prim.Comparison(prim.Variable("rotate_index"), "!=", 0)
check = prim.LogicalAnd((rotate_check, total_check))
# Insert the 'continue' statement
new_insns.append(lp.CInstruction((), # iname exprs that the code needs access to
"continue;", # the code
predicates=frozenset({check}),
depends_on=frozenset({"update_rotate_index", "update_total_index"}).union(frozenset([i.id for i in write_insns])),
depends_on_is_final=True,
within_inames=common_inames.union(inames),
within_inames_is_final=True,
id="continue_stmt",
))
#
# Construct a flat loop for the given instructions
# Reconstruct the compute instructions
#
# new_insns = []
# other_insns = [i for i in knl.instructions if i.id not in [j.id for j in insns]]
#
# size = prim.Product(tuple(pw_aff_to_expr(knl.get_iname_bounds(i).size) for i in inames))
# size = prim.FloorDiv(size, vec_size)
#
# temporaries = knl.temporary_variables
# temporaries["flatsize"] = lp.TemporaryVariable("flatsize",
# dtype=np.int32,
# shape=(),
# )
# new_insns.append(lp.Assignment(prim.Variable("flatsize"),
# size,
# )
# )
#
# # Add an additional domain to the kernel
# new_iname = "flat_{}".format("_".join(inames))
# domain = "{{ [{0}] : 0<={0}<flatsize }}".format(new_iname, str(size))
# domain = parse_domains(domain, {})
# knl = knl.copy(domains=knl.domains + domain,
# temporary_variables=temporaries)
#
# # Split and tag the flat iname
# knl = lp.split_iname(knl, new_iname, vec_size, inner_tag="vec")
# new_inames = ("{}_outer".format(new_iname), "{}_inner".format(new_iname))
# knl = lp.assume(knl, "flatsize mod {} = 0".format(vec_size))
#
# for insn in insns:
# # Get a vector view of the lhs expression
# lhsname = get_pymbolic_basename(insn.assignee)
# knl = add_vector_view(knl, lhsname)
# lhsname = get_vector_view_name(lhsname)
#
# new_insns.append(lp.Assignment(prim.Subscript(prim.Variable(lhsname), tuple(prim.Variable(i) for i in new_inames)),
# prim.Subscript(prim.Variable(get_vector_view_name("wk_precomputed")), tuple(prim.Variable(i) for i in new_inames)),
# within_inames=frozenset(new_inames),
# within_inames_is_final=True,
# )
# )
for insn in insns:
# Get a vector view of the lhs expression
lhsname = get_pymbolic_basename(insn.assignee)
knl = add_vector_view(knl, lhsname)
lhsname = get_vector_view_name(lhsname)
new_insns.append(lp.Assignment(prim.Subscript(prim.Variable(lhsname), (prim.FloorDiv(prim.Variable("total_size"), vec_size), prim.Variable(new_iname))),
substitute(insn.expression, replacemap_vec),
within_inames=frozenset(inames + (new_iname,)),
within_inames_is_final=True,
)
)
return knl.copy(instructions=new_insns + other_insns)
......@@ -43,9 +43,18 @@ def add_vector_view(knl, tmpname):
shape=(size, vecsize),
base_storage=tmpname + "_base",
dtype=np.float64,
scope=lp.temp_var_scope.PRIVATE,
)
return knl.copy(temporary_variables=temporaries)
# Prevent any of these temporaries from being eliminated
silenced = ['temp_to_write({})'.format(tmpname),
'temp_to_write({})'.format(vecname),
'read_no_write({})'.format(tmpname),
'read_no_write({})'.format(vecname),
]
return knl.copy(temporary_variables=temporaries,
silenced_warnings=knl.silenced_warnings + silenced)
def add_temporary_with_vector_view(knl, name, *args, **kwargs):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment