From 4b1eb0bc42ed0205be48ffa54c2cade677ed6e13 Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Mon, 21 Nov 2016 13:31:05 +0100 Subject: [PATCH] Finalize vectorization of quadrature loop --- .../loopy/transformations/collect_rotate.py | 101 +++++++++++------- 1 file changed, 61 insertions(+), 40 deletions(-) diff --git a/python/dune/perftool/loopy/transformations/collect_rotate.py b/python/dune/perftool/loopy/transformations/collect_rotate.py index f7184efd..95d71261 100644 --- a/python/dune/perftool/loopy/transformations/collect_rotate.py +++ b/python/dune/perftool/loopy/transformations/collect_rotate.py @@ -53,6 +53,8 @@ def collect_vector_data_rotate(knl, insns, inames): knl = knl.copy(domains=knl.domains + domain) knl = lp.tag_inames(knl, [(new_iname, "vec")]) + new_insns = [] + # # Inspect the given instructions for dependent quantities # @@ -70,51 +72,61 @@ def collect_vector_data_rotate(knl, insns, inames): replacemap_vec = {} for quantity in quantities: expr, = quantities[quantity] - arrname = quantity + '_buffered_arr' - knl = add_temporary_with_vector_view(knl, - arrname, - dtype=np.float64, - shape=(vec_size,), - dim_tags="c", - base_storage=quantity + '_base_storage', - scope=lp.temp_var_scope.PRIVATE, - ) - - replacemap_arr[quantity] = prim.Subscript(prim.Variable(arrname), (prim.Variable('rotate_index'),)) - replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(arrname)), (0, prim.Variable(new_iname),)) - - write_match = lp.match.Or(tuple(lp.match.Writes(q) for q in quantities)) - iname_match = lp.match.And(tuple(lp.match.Iname(i) for i in inames)) - match = lp.match.And((write_match, iname_match)) - write_insns = lp.find_instructions(knl, match) - - other_insns = [i for i in knl.instructions if i.id not in [j.id for j in insns + write_insns]] - new_insns = [] - temporaries = knl.temporary_variables - for insn in write_insns: - if isinstance(insn, lp.Assignment): - new_insns.append(insn.copy(assignee=replacemap_arr[get_pymbolic_basename(insn.assignee)], - depends_on_is_final=True, - ) - ) - elif isinstance(insn, lp.CInstruction): - # Rip apart the code and change the assignee - assignee, expression = insn.code.split("=") - assignee = assignee.strip() - assert assignee in replacemap_arr - - code = "{} ={}".format(str(replacemap_arr[assignee]), expression) - new_insns.append(insn.copy(code=code, - depends_on_is_final=True, - )) + # Check whether there is an instruction that writes this quantity within + # the given inames. If so, we need a buffer array. + iname_match = lp.match.And(tuple(lp.match.Iname(i) for i in inames)) + write_match = lp.match.Writes(quantity) + match = lp.match.And((iname_match, write_match)) + write_insns = lp.find_instructions(knl, match) + + if write_insns: + arrname = quantity + '_buffered_arr' + knl = add_temporary_with_vector_view(knl, + arrname, + dtype=np.float64, + shape=(vec_size,), + dim_tags="c", + base_storage=quantity + '_base_storage', + scope=lp.temp_var_scope.PRIVATE, + ) + + replacemap_arr[quantity] = prim.Subscript(prim.Variable(arrname), (prim.Variable('rotate_index'),)) + replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(arrname)), (0, prim.Variable(new_iname),)) + + for insn in write_insns: + if isinstance(insn, lp.Assignment): + new_insns.append(insn.copy(assignee=replacemap_arr[get_pymbolic_basename(insn.assignee)], + depends_on_is_final=True, + ) + ) + elif isinstance(insn, lp.CInstruction): + # Rip apart the code and change the assignee + assignee, expression = insn.code.split("=") + assignee = assignee.strip() + assert assignee in replacemap_arr + + code = "{} ={}".format(str(replacemap_arr[assignee]), expression) + new_insns.append(insn.copy(code=code, + depends_on_is_final=True, + )) + else: + raise NotImplementedError else: - raise NotImplementedError + # Add a vector view to this quantity + knl = add_vector_view(knl, quantity) + replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)), + (prim.Sum((prim.FloorDiv(prim.Variable("total_index"), vec_size), -1)), prim.Variable(new_iname)), + ) + + other_insns = [i for i in knl.instructions if i.id not in [j.id for j in insns + new_insns]] # # Add two counter variables to the kernel # + temporaries = knl.temporary_variables + # Insert a flat consecutive counter 'total_index' temporaries['total_index'] = lp.TemporaryVariable('total_index', # name dtype=np.int32, @@ -155,6 +167,13 @@ def collect_vector_data_rotate(knl, insns, inames): id="update_rotate_index", )) + knl = knl.copy(temporary_variables=temporaries) + + # + # Add a continue statement depending on the rotate index + # + + # Determine the condition for the continue statement upper_bound = prim.Product(tuple(pw_aff_to_expr(knl.get_iname_bounds(i).size) for i in inames)) total_check = prim.Comparison(prim.Variable("total_index"), "<", upper_bound) @@ -182,11 +201,13 @@ def collect_vector_data_rotate(knl, insns, inames): knl = add_vector_view(knl, lhsname) lhsname = get_vector_view_name(lhsname) - new_insns.append(lp.Assignment(prim.Subscript(prim.Variable(lhsname), (prim.FloorDiv(prim.Variable("total_index"), vec_size), prim.Variable(new_iname))), + new_insns.append(lp.Assignment(prim.Subscript(prim.Variable(lhsname), + (prim.Sum((prim.FloorDiv(prim.Variable("total_index"), vec_size), -1)), prim.Variable(new_iname)), + ), substitute(insn.expression, replacemap_vec), depends_on=frozenset({"continue_stmt"}), depends_on_is_final=True, - within_inames=frozenset(inames + (new_iname,)), + within_inames=common_inames.union(frozenset(inames + (new_iname,))), within_inames_is_final=True, id=insn.id, ) -- GitLab