From 38574e66d6d1c5559e9a2164a46c83a3fd4e5ee8 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Wed, 7 Dec 2016 09:49:53 +0100
Subject: [PATCH] Reintroduce total_index

---
 .../loopy/transformations/collect_rotate.py   | 27 +++++++++++++++----
 1 file changed, 22 insertions(+), 5 deletions(-)

diff --git a/python/dune/perftool/loopy/transformations/collect_rotate.py b/python/dune/perftool/loopy/transformations/collect_rotate.py
index 6e9627a0..f585c79e 100644
--- a/python/dune/perftool/loopy/transformations/collect_rotate.py
+++ b/python/dune/perftool/loopy/transformations/collect_rotate.py
@@ -173,11 +173,29 @@ def collect_vector_data_rotate(knl):
     other_insns = [i for i in knl.instructions if i.id not in [j.id for j in insns + new_insns]]
 
     #
-    # Add two counter variables to the kernel
+    # Add three counter variables to the kernel
     #
-
     temporaries = knl.temporary_variables
 
+    temporaries['total_index'] = lp.TemporaryVariable('total_index',
+                                                      dtype=np.int32,
+                                                      scope=lp.temp_var_scope.PRIVATE,
+                                                      )
+    new_insns.append(lp.Assignment(prim.Variable("total_index"),  # assignee
+                                   0,  # expression
+                                   within_inames=common_inames,
+                                   within_inames_is_final=True,
+                                   id="assign_total_index",
+                                   ))
+    new_insns.append(lp.Assignment(prim.Variable("total_index"),  # assignee
+                                   prim.Sum((prim.Variable("total_index"), 1)),  # expression
+                                   within_inames=common_inames.union(inames),
+                                   within_inames_is_final=True,
+                                   depends_on=frozenset(all_writers).union(frozenset({"assign_total_index"})),
+                                   depends_on_is_final=True,
+                                   id="update_total_index",
+                                   ))
+
     # Insert a flat consecutive counter 'vec_index', which is increased after a vector chunk is handled
     temporaries['vec_index'] = lp.TemporaryVariable('vec_index',  # name
                                                     dtype=np.int32,
@@ -226,8 +244,7 @@ def collect_vector_data_rotate(knl):
 
     # Determine the condition for the continue statement
     upper_bound = prim.Product(tuple(pw_aff_to_expr(knl.get_iname_bounds(i).size) for i in inames))
-    stride = 1 if rotating else vec_size
-    total_check = prim.Comparison(stride * prim.Variable("vec_index") + prim.Variable("rotate_index"), "<", upper_bound)
+    total_check = prim.Comparison(prim.Variable("total_index"), "<", upper_bound)
     rotate_check = prim.Comparison(prim.Variable("rotate_index"), "!=", 0)
     check = prim.LogicalAnd((rotate_check, total_check))
 
@@ -235,7 +252,7 @@ def collect_vector_data_rotate(knl):
     new_insns.append(lp.CInstruction((),  # iname exprs that the code needs access to
                                      "continue;",  # the code
                                      predicates=frozenset({check}),
-                                     depends_on=frozenset({"update_rotate_index"}).union(frozenset(all_writers)),
+                                     depends_on=frozenset({"update_rotate_index", "update_total_index"}).union(frozenset(all_writers)),
                                      depends_on_is_final=True,
                                      within_inames=common_inames.union(inames),
                                      within_inames_is_final=True,
-- 
GitLab