From d2010bf5fd96d24d7acc2329432d3c7bf401ccf6 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Wed, 18 Jan 2017 11:24:51 +0100
Subject: [PATCH] Implement quadrature loop vectorization with temporaries of
 arbitrary shape

---
 .../loopy/transformations/collect_rotate.py   |  4 ++--
 .../loopy/transformations/vectorview.py       | 22 ++++++++++++++++---
 2 files changed, 21 insertions(+), 5 deletions(-)

diff --git a/python/dune/perftool/loopy/transformations/collect_rotate.py b/python/dune/perftool/loopy/transformations/collect_rotate.py
index 9004b2f2..d633843a 100644
--- a/python/dune/perftool/loopy/transformations/collect_rotate.py
+++ b/python/dune/perftool/loopy/transformations/collect_rotate.py
@@ -167,7 +167,7 @@ def collect_vector_data_rotate(knl):
                 #
 
                 # 1. Rotating the input data
-                knl = add_vector_view(knl, quantity)
+                knl = add_vector_view(knl, quantity, flatview=True)
                 include_file("dune/perftool/sumfact/transposereg.hh", filetag="operatorfile")
                 new_insns.append(lp.CallInstruction((),  # assignees
                                                     prim.Call(prim.Variable("transpose_reg"),
@@ -289,7 +289,7 @@ def collect_vector_data_rotate(knl):
     for insn in insns:
         # Get a vector view of the lhs expression
         lhsname = get_pymbolic_basename(insn.assignee)
-        knl = add_vector_view(knl, lhsname, pad_to=vec_size)
+        knl = add_vector_view(knl, lhsname, pad_to=vec_size, flatview=True)
         lhsname = get_vector_view_name(lhsname)
 
         if rotating:
diff --git a/python/dune/perftool/loopy/transformations/vectorview.py b/python/dune/perftool/loopy/transformations/vectorview.py
index e0d78e14..1450b26e 100644
--- a/python/dune/perftool/loopy/transformations/vectorview.py
+++ b/python/dune/perftool/loopy/transformations/vectorview.py
@@ -16,7 +16,7 @@ def get_vector_view_name(tmpname):
     return tmpname + "_vec"
 
 
-def add_vector_view(knl, tmpname, pad_to=None):
+def add_vector_view(knl, tmpname, pad_to=None, flatview=False):
     """
     Kernel transformation to add a vector view temporary
     that interprets the same memory as another temporary
@@ -53,11 +53,27 @@ def add_vector_view(knl, tmpname, pad_to=None):
     if pad_to:
         size = (size // pad_to + 1) * pad_to
 
+    # Some vectorview are intentionally flat! (e.g. the output buffers of
+    # sum factorization kernels
+    if flatview:
+        shape = (size, vecsize)
+        dim_tags = "c,vec"
+    else:
+        assert(temp.shape[-1] == vecsize)
+        shape = temp.shape
+        # This works around a loopy weirdness (which might as well be a bug)
+        # TODO: investigate this!
+        if len(shape) == 1:
+            shape = (1, vecsize)
+            dim_tags = "c,vec"
+        else:
+            dim_tags = temp.dim_tags[:-1] + ("vec",)
+
     # Now add a vector view temporary
     vecname = tmpname + "_vec"
     temporaries[vecname] = lp.TemporaryVariable(vecname,
-                                                dim_tags="c,vec",
-                                                shape=(size, vecsize),
+                                                dim_tags=dim_tags,
+                                                shape=shape,
                                                 base_storage=bsname,
                                                 dtype=np.float64,
                                                 scope=lp.temp_var_scope.PRIVATE,
-- 
GitLab