diff --git a/python/dune/codegen/loopy/transformations/remove_reductions.py b/python/dune/codegen/loopy/transformations/remove_reductions.py index c49cb86f1ac6f4ac1300a3b8c716fc30414f8c68..c7848a315a4a2b1e48b5368ea3a6416f7e3ed13b 100644 --- a/python/dune/codegen/loopy/transformations/remove_reductions.py +++ b/python/dune/codegen/loopy/transformations/remove_reductions.py @@ -4,11 +4,19 @@ import pymbolic.primitives as prim def remove_reduction(knl, match): """Removes all matching reductions and do direct accumulation in assignee instead""" - instructions = [] # Find reductions for instr in lp.find_instructions(knl, match): if isinstance(instr.expression, lp.symbolic.Reduction): + instructions = [] + depends_on = instr.depends_on + + # Depending on this instruction + depending = [] + for i in knl.instructions: + if instr.id in i.depends_on: + depending.append(i.id) + # Remove the instruction from the kernel knl = lp.remove_instructions(knl, set([instr.id])) @@ -30,11 +38,15 @@ def remove_reduction(knl, match): expression, within_inames=within_inames, id=id_accum, - depends_on=frozenset((id_zero,)), + depends_on=frozenset((id_zero,) + tuple(depends_on)), tags=('assignement',))) - knl = knl.copy(instructions=knl.instructions + instructions) + knl = knl.copy(instructions=knl.instructions + instructions) + + for dep in depending: + match = lp.match.Id(dep) + knl = lp.add_dependency(knl, match, id_accum) return knl