diff --git a/python/dune/perftool/loopy/collectvector.py b/python/dune/perftool/loopy/collectvector.py index 37377d311fc394d8e0143cda01a00906d3c07dcb..b5023e0e8fce7371c8a79d2c586cc5d64bc4bb4e 100644 --- a/python/dune/perftool/loopy/collectvector.py +++ b/python/dune/perftool/loopy/collectvector.py @@ -154,6 +154,7 @@ def collect_vector_data(knl, insns, inames, vec_size=4): # Pre-evaluate all the needed quantities replacemap_arr = {} + replacemap_poi = {} replacemap_vec = {} for quantity, quantity_exprs in quantities.items(): # TODO for now I only consider the case where an array occurs but once! @@ -181,6 +182,7 @@ def collect_vector_data(knl, insns, inames, vec_size=4): ) replacemap_arr[quantity] = prim.Subscript(prim.Variable(arrname), (prim.Variable('rotate_index'),)) + replacemap_poi[quantity] = prim.Variable(arrname) replacemap_vec[quantity_expr] = prim.Variable(vecname) for insn in write_insns: @@ -189,9 +191,13 @@ def collect_vector_data(knl, insns, inames, vec_size=4): ) ) if isinstance(insn, lp.CInstruction): - # TODO: What do we do about CInstructions? - # Example: detjac = ... - new_insns.append(insn) + # Rip apart the code and change the assignee + assignee, expression = insn.code.split("=") + assignee = assignee.strip() + assert assignee in replacemap_arr + + code = "{} ={}".format(str(replacemap_arr[assignee]), expression) + new_insns.append(insn.copy(code=code)) # Determine the condition for the continue statement upper_bound = prim.Product(tuple(pw_aff_to_expr(knl.get_iname_bounds(i).size) for i in inames)) @@ -222,7 +228,7 @@ def collect_vector_data(knl, insns, inames, vec_size=4): for expr in depmapper(insn.expression): name = get_pymbolic_basename(expr) new_insns.append(load_instruction(replacemap_vec[expr], - replacemap_vec[expr], + replacemap_poi[name], depends_on=frozenset({"continue_stmt"}), within_inames=inames, within_inames_is_final=True,