From d8df729e25ed932c53d8bb223e0ff33cfd719c87 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ren=C3=A9=20He=C3=9F?= <rene.hess@iwr.uni-heidelberg.de>
Date: Sun, 10 Feb 2019 21:35:12 +0100
Subject: [PATCH] [skip ci] Fix dependency bug in reduction removal

---
 .../loopy/transformations/remove_reductions.py | 18 +++++++++++++++---
 1 file changed, 15 insertions(+), 3 deletions(-)

diff --git a/python/dune/codegen/loopy/transformations/remove_reductions.py b/python/dune/codegen/loopy/transformations/remove_reductions.py
index c49cb86f..c7848a31 100644
--- a/python/dune/codegen/loopy/transformations/remove_reductions.py
+++ b/python/dune/codegen/loopy/transformations/remove_reductions.py
@@ -4,11 +4,19 @@ import pymbolic.primitives as prim
 
 def remove_reduction(knl, match):
     """Removes all matching reductions and do direct accumulation in assignee instead"""
-    instructions = []
 
     # Find reductions
     for instr in lp.find_instructions(knl, match):
         if isinstance(instr.expression, lp.symbolic.Reduction):
+            instructions = []
+            depends_on = instr.depends_on
+
+            # Depending on this instruction
+            depending = []
+            for i in knl.instructions:
+                if instr.id in i.depends_on:
+                    depending.append(i.id)
+
             # Remove the instruction from the kernel
             knl = lp.remove_instructions(knl, set([instr.id]))
 
@@ -30,11 +38,15 @@ def remove_reduction(knl, match):
                                               expression,
                                               within_inames=within_inames,
                                               id=id_accum,
-                                              depends_on=frozenset((id_zero,)),
+                                              depends_on=frozenset((id_zero,) + tuple(depends_on)),
                                               tags=('assignement',)))
 
 
-    knl = knl.copy(instructions=knl.instructions + instructions)
+            knl = knl.copy(instructions=knl.instructions + instructions)
+
+            for dep in depending:
+                match = lp.match.Id(dep)
+                knl = lp.add_dependency(knl, match, id_accum)
     return knl
 
 
-- 
GitLab