From 4b1eb0bc42ed0205be48ffa54c2cade677ed6e13 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Mon, 21 Nov 2016 13:31:05 +0100
Subject: [PATCH] Finalize vectorization of quadrature loop

---
 .../loopy/transformations/collect_rotate.py   | 101 +++++++++++-------
 1 file changed, 61 insertions(+), 40 deletions(-)

diff --git a/python/dune/perftool/loopy/transformations/collect_rotate.py b/python/dune/perftool/loopy/transformations/collect_rotate.py
index f7184efd..95d71261 100644
--- a/python/dune/perftool/loopy/transformations/collect_rotate.py
+++ b/python/dune/perftool/loopy/transformations/collect_rotate.py
@@ -53,6 +53,8 @@ def collect_vector_data_rotate(knl, insns, inames):
     knl = knl.copy(domains=knl.domains + domain)
     knl = lp.tag_inames(knl, [(new_iname, "vec")])
 
+    new_insns = []
+
     #
     # Inspect the given instructions for dependent quantities
     #
@@ -70,51 +72,61 @@ def collect_vector_data_rotate(knl, insns, inames):
     replacemap_vec = {}
     for quantity in quantities:
         expr, = quantities[quantity]
-        arrname = quantity + '_buffered_arr'
-        knl = add_temporary_with_vector_view(knl,
-                                             arrname,
-                                             dtype=np.float64,
-                                             shape=(vec_size,),
-                                             dim_tags="c",
-                                             base_storage=quantity + '_base_storage',
-                                             scope=lp.temp_var_scope.PRIVATE,
-                                             )
-
-        replacemap_arr[quantity] = prim.Subscript(prim.Variable(arrname), (prim.Variable('rotate_index'),))
-        replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(arrname)), (0, prim.Variable(new_iname),))
-
-    write_match = lp.match.Or(tuple(lp.match.Writes(q) for q in quantities))
-    iname_match = lp.match.And(tuple(lp.match.Iname(i) for i in inames))
-    match = lp.match.And((write_match, iname_match))
-    write_insns = lp.find_instructions(knl, match)
-
-    other_insns = [i for i in knl.instructions if i.id not in [j.id for j in insns + write_insns]]
-    new_insns = []
-    temporaries = knl.temporary_variables
 
-    for insn in write_insns:
-        if isinstance(insn, lp.Assignment):
-            new_insns.append(insn.copy(assignee=replacemap_arr[get_pymbolic_basename(insn.assignee)],
-                                       depends_on_is_final=True,
-                                       )
-                             )
-        elif isinstance(insn, lp.CInstruction):
-            # Rip apart the code and change the assignee
-            assignee, expression = insn.code.split("=")
-            assignee = assignee.strip()
-            assert assignee in replacemap_arr
-
-            code = "{} ={}".format(str(replacemap_arr[assignee]), expression)
-            new_insns.append(insn.copy(code=code,
-                                       depends_on_is_final=True,
-                                       ))
+        # Check whether there is an instruction that writes this quantity within
+        # the given inames. If so, we need a buffer array.
+        iname_match = lp.match.And(tuple(lp.match.Iname(i) for i in inames))
+        write_match = lp.match.Writes(quantity)
+        match = lp.match.And((iname_match, write_match))
+        write_insns = lp.find_instructions(knl, match)
+
+        if write_insns:
+            arrname = quantity + '_buffered_arr'
+            knl = add_temporary_with_vector_view(knl,
+                                                 arrname,
+                                                 dtype=np.float64,
+                                                 shape=(vec_size,),
+                                                 dim_tags="c",
+                                                 base_storage=quantity + '_base_storage',
+                                                 scope=lp.temp_var_scope.PRIVATE,
+                                                 )
+
+            replacemap_arr[quantity] = prim.Subscript(prim.Variable(arrname), (prim.Variable('rotate_index'),))
+            replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(arrname)), (0, prim.Variable(new_iname),))
+
+            for insn in write_insns:
+                if isinstance(insn, lp.Assignment):
+                    new_insns.append(insn.copy(assignee=replacemap_arr[get_pymbolic_basename(insn.assignee)],
+                                               depends_on_is_final=True,
+                                               )
+                                 )
+                elif isinstance(insn, lp.CInstruction):
+                    # Rip apart the code and change the assignee
+                    assignee, expression = insn.code.split("=")
+                    assignee = assignee.strip()
+                    assert assignee in replacemap_arr
+
+                    code = "{} ={}".format(str(replacemap_arr[assignee]), expression)
+                    new_insns.append(insn.copy(code=code,
+                                               depends_on_is_final=True,
+                                               ))
+                else:
+                    raise NotImplementedError
         else:
-            raise NotImplementedError
+            # Add a vector view to this quantity
+            knl = add_vector_view(knl, quantity)
+            replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)),
+                                                  (prim.Sum((prim.FloorDiv(prim.Variable("total_index"), vec_size), -1)), prim.Variable(new_iname)),
+                                                  )
+
+    other_insns = [i for i in knl.instructions if i.id not in [j.id for j in insns + new_insns]]
 
     #
     # Add two counter variables to the kernel
     #
 
+    temporaries = knl.temporary_variables
+
     # Insert a flat consecutive counter 'total_index'
     temporaries['total_index'] = lp.TemporaryVariable('total_index',  # name
                                                       dtype=np.int32,
@@ -155,6 +167,13 @@ def collect_vector_data_rotate(knl, insns, inames):
                                    id="update_rotate_index",
                                    ))
 
+    knl = knl.copy(temporary_variables=temporaries)
+
+    #
+    # Add a continue statement depending on the rotate index
+    #
+
+
     # Determine the condition for the continue statement
     upper_bound = prim.Product(tuple(pw_aff_to_expr(knl.get_iname_bounds(i).size) for i in inames))
     total_check = prim.Comparison(prim.Variable("total_index"), "<", upper_bound)
@@ -182,11 +201,13 @@ def collect_vector_data_rotate(knl, insns, inames):
         knl = add_vector_view(knl, lhsname)
         lhsname = get_vector_view_name(lhsname)
 
-        new_insns.append(lp.Assignment(prim.Subscript(prim.Variable(lhsname), (prim.FloorDiv(prim.Variable("total_index"), vec_size), prim.Variable(new_iname))),
+        new_insns.append(lp.Assignment(prim.Subscript(prim.Variable(lhsname),
+                                                      (prim.Sum((prim.FloorDiv(prim.Variable("total_index"), vec_size), -1)), prim.Variable(new_iname)),
+                                                      ),
                                        substitute(insn.expression, replacemap_vec),
                                        depends_on=frozenset({"continue_stmt"}),
                                        depends_on_is_final=True,
-                                       within_inames=frozenset(inames + (new_iname,)),
+                                       within_inames=common_inames.union(frozenset(inames + (new_iname,))),
                                        within_inames_is_final=True,
                                        id=insn.id,
                                        )
-- 
GitLab