From e3b1849da7bcae62d56235140db95e620023ee45 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Tue, 6 Dec 2016 17:46:12 +0100
Subject: [PATCH] Combine vectorization strategies

Only 6 orders of magnitude to go!
---
 python/dune/perftool/loopy/target.py          |  1 +
 .../loopy/transformations/collect_rotate.py   | 79 ++++++++++++++++---
 python/dune/perftool/sumfact/sumfact.py       |  5 +-
 python/dune/perftool/tools.py                 | 17 ++++
 test/sumfact/poisson/poisson_3d_order1.mini   |  1 -
 test/sumfact/poisson/poisson_3d_order2.mini   |  1 -
 6 files changed, 91 insertions(+), 13 deletions(-)

diff --git a/python/dune/perftool/loopy/target.py b/python/dune/perftool/loopy/target.py
index 055d06b4..3bb99b44 100644
--- a/python/dune/perftool/loopy/target.py
+++ b/python/dune/perftool/loopy/target.py
@@ -31,6 +31,7 @@ import cgen
 def _type_to_op_counter_type(name):
     return "oc::OpCounter<{}>".format(name)
 
+
 @pt.memoize
 def numpy_to_cpp_dtype(key):
     _registry = {'float32': 'float',
diff --git a/python/dune/perftool/loopy/transformations/collect_rotate.py b/python/dune/perftool/loopy/transformations/collect_rotate.py
index c4e901e4..b011bf8b 100644
--- a/python/dune/perftool/loopy/transformations/collect_rotate.py
+++ b/python/dune/perftool/loopy/transformations/collect_rotate.py
@@ -1,26 +1,39 @@
 """ A kernel transformation that precomputes quantities until a vector register
 is filled and then does vector computations """
 
-from dune.perftool.loopy.vcl import get_vcl_type_size
+from dune.perftool.generation import (function_mangler,
+                                      include_file,
+                                      )
+from dune.perftool.loopy.vcl import get_vcl_type, get_vcl_type_size
 from dune.perftool.loopy.transformations.vectorview import (add_temporary_with_vector_view,
                                                             add_vector_view,
                                                             get_vector_view_name,
                                                             )
+from dune.perftool.loopy.symbolic import substitute
 from dune.perftool.sumfact.quadrature import quadrature_inames
-from dune.perftool.tools import get_pymbolic_basename
+from dune.perftool.tools import get_pymbolic_basename, get_pymbolic_tag
 
 from loopy.kernel.creation import parse_domains
 from loopy.symbolic import pw_aff_to_expr
 from loopy.match import Tagged
 
 from pymbolic.mapper.dependency import DependencyMapper
-from pymbolic.mapper.substitutor import substitute
 
 import pymbolic.primitives as prim
 import loopy as lp
 import numpy as np
 
 
+@function_mangler
+def rotate_function_mangler(knl, func, arg_dtypes):
+    if func == "transpose_reg":
+        # This is not 100% within the loopy philosophy, as we are
+        # passing the vector registers as references and have them
+        # changed. Loopy assumes this function to be read-only.
+        vcl = lp.types.NumpyType(get_vcl_type(np.float64, register_size=256))
+        return lp.CallMangleInfo("transpose_reg", (), (vcl, vcl, vcl, vcl))
+
+
 def collect_vector_data_rotate(knl):
     #
     # Process/Assert/Standardize the input
@@ -48,6 +61,7 @@ def collect_vector_data_rotate(knl):
 
     new_insns = []
     all_writers = []
+    rotating = False
 
     #
     # Inspect the given instructions for dependent quantities
@@ -123,11 +137,38 @@ def collect_vector_data_rotate(knl):
                 else:
                     raise NotImplementedError
         elif quantity in knl.temporary_variables:
-            # Add a vector view to this quantity
-            knl = add_vector_view(knl, quantity)
-            replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)),
-                                                  (prim.Variable("vec_index"), prim.Variable(new_iname)),
-                                                  )
+            if all(get_pymbolic_tag(expr) == 'vector' for expr in quantities[quantity]):
+                #
+                # There is a vector quantity to be vectorized! That requires register rotation!
+                #
+
+                # 1. Rotating the input data
+                knl = add_vector_view(knl, quantity)
+                include_file("dune/perftool/sumfact/transposereg.hh", filetag="operatorfile")
+                new_insns.append(lp.CallInstruction((),  # assignees
+                                                    prim.Call(prim.Variable("transpose_reg"),
+                                                              tuple(prim.Subscript(prim.Variable(get_vector_view_name(quantity)), (prim.Variable("vec_index") + i, prim.Variable(new_iname))) for i in range(4))),
+                                                    depends_on=frozenset({'continue_stmt'}),
+                                                    within_inames=common_inames.union(inames).union(frozenset({new_iname})),
+                                                    within_inames_is_final=True,
+                                                    id="{}_rotate".format(quantity),
+                                                    ))
+
+                # Add substitution rules
+                for expr in quantities[quantity]:
+                    rotating = True
+                    assert isinstance(expr, prim.Subscript)
+                    last_index = expr.index[-1]
+                    assert last_index in tuple(range(4))
+                    replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)),
+                                                          (prim.Variable("vec_index") + last_index, prim.Variable(new_iname)),
+                                                          )
+            else:
+                # Add a vector view to this quantity
+                knl = add_vector_view(knl, quantity)
+                replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)),
+                                                      (prim.Variable("vec_index"), prim.Variable(new_iname)),
+                                                      )
 
     other_insns = [i for i in knl.instructions if i.id not in [j.id for j in insns + new_insns]]
 
@@ -149,7 +190,7 @@ def collect_vector_data_rotate(knl):
                                    id="assign_vec_index",
                                    ))
     new_insns.append(lp.Assignment(prim.Variable("vec_index"),  # assignee
-                                   prim.Sum((prim.Variable("vec_index"), 1)),  # expression
+                                   prim.Sum((prim.Variable("vec_index"), vec_size if rotating else 1)),  # expression
                                    within_inames=common_inames.union(inames),
                                    within_inames_is_final=True,
                                    depends_on=frozenset({Tagged("vec_write"), "assign_vec_index"}),
@@ -210,8 +251,15 @@ def collect_vector_data_rotate(knl):
         knl = add_vector_view(knl, lhsname)
         lhsname = get_vector_view_name(lhsname)
 
+        if rotating:
+            assert isinstance(insn.assignee, prim.Subscript)
+            last_index = insn.assignee.index[-1]
+            assert last_index in tuple(range(4))
+        else:
+            last_index = 0
+
         new_insns.append(lp.Assignment(prim.Subscript(prim.Variable(lhsname),
-                                                      (prim.Variable("vec_index"), prim.Variable(new_iname)),
+                                                      (prim.Variable("vec_index") + last_index, prim.Variable(new_iname)),
                                                       ),
                                        substitute(insn.expression, replacemap_vec),
                                        depends_on=frozenset({"continue_stmt"}),
@@ -223,4 +271,15 @@ def collect_vector_data_rotate(knl):
                                        )
                          )
 
+    # Rotate back!
+    if rotating:
+        new_insns.append(lp.CallInstruction((),  # assignees
+                                            prim.Call(prim.Variable("transpose_reg"),
+                                                      tuple(prim.Subscript(prim.Variable(lhsname), (prim.Variable("vec_index") + i, prim.Variable(new_iname))) for i in range(4))),
+                                            depends_on=frozenset({Tagged("vec_write")}),
+                                            within_inames=common_inames.union(inames).union(frozenset({new_iname})),
+                                            within_inames_is_final=True,
+                                            id="{}_rotateback".format(lhsname),
+                                            ))
+
     return knl.copy(instructions=new_insns + other_insns)
diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py
index 57d25588..d6f11c77 100644
--- a/python/dune/perftool/sumfact/sumfact.py
+++ b/python/dune/perftool/sumfact/sumfact.py
@@ -344,4 +344,7 @@ def sum_factorization_kernel(a_matrices, buf, stage, insn_dep=frozenset({}), add
                                )
     silenced_warning('read_no_write({})'.format(out))
 
-    return prim.Variable(out), insn_dep
+    if isinstance(next(iter(a_matrices)), LargeAMatrix):
+        return lp.TaggedVariable(out, "vector"), insn_dep
+    else:
+        return prim.Variable(out), insn_dep
diff --git a/python/dune/perftool/tools.py b/python/dune/perftool/tools.py
index e9619b83..e6259e44 100644
--- a/python/dune/perftool/tools.py
+++ b/python/dune/perftool/tools.py
@@ -1,5 +1,7 @@
 """ Some grabbag tools """
+from __future__ import absolute_import
 
+import loopy as lp
 import pymbolic.primitives as prim
 
 
@@ -42,3 +44,18 @@ def maybe_wrap_subscript(expr, indices):
             return prim.Subscript(expr, indices)
     else:
         return expr
+
+
+def get_pymbolic_tag(expr):
+    assert isinstance(expr, prim.Expression)
+
+    if isinstance(expr, lp.TaggedVariable):
+        return expr.tag
+
+    if isinstance(expr, prim.Variable):
+        return None
+
+    if isinstance(expr, prim.Subscript):
+        return get_pymbolic_tag(expr.aggregate)
+
+    raise NotImplementedError("Cannot determine tag on {}".format(expr))
diff --git a/test/sumfact/poisson/poisson_3d_order1.mini b/test/sumfact/poisson/poisson_3d_order1.mini
index a636bda7..d794e0dc 100644
--- a/test/sumfact/poisson/poisson_3d_order1.mini
+++ b/test/sumfact/poisson/poisson_3d_order1.mini
@@ -4,7 +4,6 @@ __exec_suffix = {diff_suffix}_{vecq_suffix}_{vecg_suffix}
 diff_suffix = numdiff, symdiff | expand num
 vecq_suffix = quadvec, nonquadvec | expand vec_q
 vecg_suffix = gradvec, nongradvec | expand vec_g
-{vecq_suffix} == quadvec and {vecg_suffix} == gradvec | exclude
 
 cells = 8 8 8
 extension = 1. 1. 1.
diff --git a/test/sumfact/poisson/poisson_3d_order2.mini b/test/sumfact/poisson/poisson_3d_order2.mini
index 24ec010e..c81eae55 100644
--- a/test/sumfact/poisson/poisson_3d_order2.mini
+++ b/test/sumfact/poisson/poisson_3d_order2.mini
@@ -4,7 +4,6 @@ __exec_suffix = {diff_suffix}_{vecq_suffix}_{vecg_suffix}
 diff_suffix = numdiff, symdiff | expand num
 vecq_suffix = quadvec, nonquadvec | expand vec_q
 vecg_suffix = gradvec, nongradvec | expand vec_g
-{vecq_suffix} == quadvec and {vecg_suffix} == gradvec | exclude
 
 cells = 8 8 8
 extension = 1. 1. 1.
-- 
GitLab