From 95dec7b28284c0ef05830bf7f382d4ba91be0db0 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Fri, 28 Jul 2017 14:14:44 +0200
Subject: [PATCH] Fix quadrature loop vectorization for multiple quadrature
 loops

---
 .../loopy/transformations/vectorize_quad.py   | 55 +++++++++++++------
 test/sumfact/stokes/stokes.mini               |  4 +-
 test/sumfact/stokes/stokes_dg.mini            |  5 +-
 3 files changed, 44 insertions(+), 20 deletions(-)

diff --git a/python/dune/perftool/loopy/transformations/vectorize_quad.py b/python/dune/perftool/loopy/transformations/vectorize_quad.py
index af20c87d..cde31e2d 100644
--- a/python/dune/perftool/loopy/transformations/vectorize_quad.py
+++ b/python/dune/perftool/loopy/transformations/vectorize_quad.py
@@ -56,11 +56,12 @@ def rotate_function_mangler(knl, func, arg_dtypes):
 
 
 class VectorIndices(object):
-    def __init__(self):
+    def __init__(self, suffix):
+        self.suffix = suffix
         self.needed = set()
 
     def get(self, increment):
-        name = "vec_index_inc{}".format(increment)
+        name = "vec_index_inc{}{}".format(increment, self.suffix)
         self.needed.add((name, increment))
         return prim.Variable(name)
 
@@ -82,15 +83,18 @@ class AntiPatternRemover(IdentityMapper):
         return IdentityMapper.map_floor_div(self, expr)
 
 
-def vectorize_quadrature_loop(knl):
+def _vectorize_quadrature_loop(knl, inames, suffix):
     #
     # Process/Assert/Standardize the input
     #
 
-    insns = [i for i in lp.find_instructions(knl, lp.match.Tagged("quadvec"))]
+    # Construct a match filter for the instructions to handle
+    tag = lp.match.Tagged("quadvec")
+    within = lp.match.And(tuple(lp.match.Iname(i) for i in inames))
+    cond = lp.match.And((tag, within))
+    insns = [i for i in lp.find_instructions(knl, cond)]
     if not insns:
         return knl
-    inames = quadrature_inames()
 
     # Analyse the inames of the given instructions and identify inames
     # that they all have in common. Those inames will also be iname dependencies
@@ -100,7 +104,7 @@ def vectorize_quadrature_loop(knl):
     # Determine the vector lane width
     # TODO infer the numpy type here
     vec_size = get_vcl_type_size(np.float64)
-    vector_indices = VectorIndices()
+    vector_indices = VectorIndices(suffix)
 
     #
     # Inspect the given instructions for dependent quantities
@@ -154,17 +158,17 @@ def vectorize_quadrature_loop(knl):
     size = ceildiv(size, vec_size)
 
     # Add an additional domain to the kernel
-    outer_iname = "flat_{}".format("_".join(inames))
+    outer_iname = "flat_{}{}".format("_".join(inames), suffix)
     o_domain = "{{ [{0}] : 0<={0}<{1} }}".format(outer_iname, size)
     o_domain = parse_domains(o_domain, {})
-    vec_iname = "vec_{}".format("_".join(inames))
+    vec_iname = "vec_{}{}".format("_".join(inames), suffix)
     i_domain = "{{ [{0}] : 0<={0}<{1} }}".format(vec_iname, vec_size)
     i_domain = parse_domains(i_domain, {})
     knl = knl.copy(domains=knl.domains + o_domain + i_domain)
     knl = lp.tag_inames(knl, [(vec_iname, "vec")])
 
     # Update instruction lists
-    insns = [i for i in lp.find_instructions(knl, lp.match.Tagged("quadvec"))]
+    insns = [i for i in lp.find_instructions(knl, cond)]
     other_insns = [i for i in knl.instructions if i.id not in [j.id for j in insns]]
     quantities = {}
     for insn in insns:
@@ -196,10 +200,10 @@ def vectorize_quadrature_loop(knl):
                                                                   tuple(prim.Subscript(prim.Variable(get_vector_view_name(quantity)),
                                                                                        (vector_indices.get(horizontal) + i, prim.Variable(vec_iname)))
                                                                         for i in range(horizontal))),
-                                                        depends_on=frozenset({'continue_stmt'}),
+                                                        depends_on=frozenset({'continue_stmt{}'.format(suffix)}),
                                                         within_inames=common_inames.union(frozenset({outer_iname, vec_iname})),
                                                         within_inames_is_final=True,
-                                                        id="{}_rotate".format(quantity),
+                                                        id="{}_rotate{}".format(quantity, suffix),
                                                         ))
 
                 # Add substitution rules
@@ -257,12 +261,12 @@ def vectorize_quadrature_loop(knl):
                                                       (vector_indices.get(horizontal) + last_index, prim.Variable(vec_iname)),
                                                       ),
                                        substitute(insn.expression, replacemap),
-                                       depends_on=frozenset({"continue_stmt"}),
+                                       depends_on=frozenset({"continue_stmt{}".format(suffix)}),
                                        depends_on_is_final=True,
                                        within_inames=common_inames.union(frozenset({outer_iname, vec_iname})),
                                        within_inames_is_final=True,
                                        id=insn.id,
-                                       tags=frozenset({"vec_write"})
+                                       tags=frozenset({"vec_write{}".format(suffix)})
                                        )
                          )
 
@@ -273,10 +277,10 @@ def vectorize_quadrature_loop(knl):
                                                           tuple(prim.Subscript(prim.Variable(lhsname),
                                                                                (vector_indices.get(horizontal) + i, prim.Variable(vec_iname)))
                                                                 for i in range(horizontal))),
-                                                depends_on=frozenset({Tagged("vec_write")}),
+                                                depends_on=frozenset({Tagged("vec_write{}".format(suffix))}),
                                                 within_inames=common_inames.union(frozenset({outer_iname, vec_iname})),
                                                 within_inames_is_final=True,
-                                                id="{}_rotateback".format(lhsname),
+                                                id="{}_rotateback{}".format(lhsname, suffix),
                                                 ))
 
     # Add the necessary vector indices
@@ -290,18 +294,33 @@ def vectorize_quadrature_loop(knl):
                                        0,  # expression
                                        within_inames=common_inames,
                                        within_inames_is_final=True,
-                                       id="assign_{}".format(name),
+                                       id="assign_{}{}".format(name, suffix),
                                        ))
         new_insns.append(lp.Assignment(prim.Variable(name),  # assignee
                                        prim.Sum((prim.Variable(name), increment)),  # expression
                                        within_inames=common_inames.union(frozenset({outer_iname})),
                                        within_inames_is_final=True,
-                                       depends_on=frozenset({Tagged("vec_write"), "assign_{}".format(name)}),
+                                       depends_on=frozenset({Tagged("vec_write{}".format(suffix)), "assign_{}{}".format(name, suffix)}),
                                        depends_on_is_final=True,
-                                       id="update_{}".format(name),
+                                       id="update_{}{}".format(name, suffix),
                                        ))
 
     from loopy.kernel.creation import resolve_dependencies
     return resolve_dependencies(knl.copy(instructions=new_insns + other_insns,
                                          temporary_variables=temporaries,
                                          ))
+
+
+def vectorize_quadrature_loop(knl):
+    # Loop over the quadrature loops that exist in the kernel.
+    # This is implemented a bit hacky right now...
+    for key, inames in quadrature_inames._memoize_cache.items():
+        element = key[0][0]
+        if element is None:
+            suffix = ''
+        else:
+            from dune.perftool.pdelab.driver import FEM_name_mangling
+            suffix = "_{}".format(FEM_name_mangling(element))
+        knl = _vectorize_quadrature_loop(knl, inames, suffix)
+
+    return knl
diff --git a/test/sumfact/stokes/stokes.mini b/test/sumfact/stokes/stokes.mini
index e453f126..9260bae2 100644
--- a/test/sumfact/stokes/stokes.mini
+++ b/test/sumfact/stokes/stokes.mini
@@ -1,7 +1,8 @@
 __name = sumfact_stokes_{__exec_suffix}
 
-__exec_suffix = {diff_suffix}
+__exec_suffix = {diff_suffix}_{quad_suffix}
 diff_suffix = numdiff, symdiff | expand num
+quad_suffix = quadvec, nonquadvec | expand quad
 
 cells = 8 8
 extension = 1. 1.
@@ -12,5 +13,6 @@ extension = vtu
 
 [formcompiler]
 numerical_jacobian = 1, 0 | expand num
+vectorize_quad = 1, 0 | expand quad
 compare_l2errorsquared = 1e-12
 sumfact = 1
diff --git a/test/sumfact/stokes/stokes_dg.mini b/test/sumfact/stokes/stokes_dg.mini
index f252f641..e533756a 100644
--- a/test/sumfact/stokes/stokes_dg.mini
+++ b/test/sumfact/stokes/stokes_dg.mini
@@ -1,6 +1,8 @@
 __name = sumfact_stokes_dg_{__exec_suffix}
 
-__exec_suffix = symdiff, numdiff | expand num
+__exec_suffix = {diff_suffix}_{vec_suffix}
+diff_suffix = symdiff, numdiff | expand num
+vec_suffix = quadvec, nonquadvec | expand vec
 
 cells = 8 8
 extension = 1. 1.
@@ -12,5 +14,6 @@ extension = vtu
 
 [formcompiler]
 numerical_jacobian = 0, 1 | expand num
+vectorize_quad = 1, 0 | expand vec
 compare_l2errorsquared = 1e-8
 sumfact = 1
-- 
GitLab