From 69158898534d77bfcb3ec1a0ec1f3592e5ee6809 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Tue, 22 Nov 2016 14:23:29 +0100
Subject: [PATCH] Implement vectorization for poisson

---
 .../loopy/transformations/collect_rotate.py   | 41 +++++++++++++------
 python/dune/perftool/pdelab/localoperator.py  | 15 +++++++
 python/dune/perftool/sumfact/basis.py         | 26 +-----------
 python/dune/perftool/sumfact/sumfact.py       |  6 +--
 test/sumfact/poisson/poisson_order1.mini      |  6 ++-
 test/sumfact/poisson/poisson_order2.mini      |  6 ++-
 6 files changed, 56 insertions(+), 44 deletions(-)

diff --git a/python/dune/perftool/loopy/transformations/collect_rotate.py b/python/dune/perftool/loopy/transformations/collect_rotate.py
index 16a472c8..13b6d505 100644
--- a/python/dune/perftool/loopy/transformations/collect_rotate.py
+++ b/python/dune/perftool/loopy/transformations/collect_rotate.py
@@ -66,14 +66,10 @@ def collect_vector_data_rotate(knl, insns, inames):
             basename = get_pymbolic_basename(expr)
             quantities.setdefault(basename, frozenset())
             quantities[basename] = quantities[basename].union(frozenset([expr]))
-    assert all(len(q) == 1 for q in quantities.values())
 
     # Add vector size buffers for all these quantities
-    replacemap_arr = {}
     replacemap_vec = {}
     for quantity in quantities:
-        expr, = quantities[quantity]
-
         # Check whether there is an instruction that writes this quantity within
         # the given inames. If so, we need a buffer array.
         iname_match = lp.match.And(tuple(lp.match.Iname(i) for i in inames))
@@ -83,38 +79,57 @@ def collect_vector_data_rotate(knl, insns, inames):
         all_writers.extend([i.id for i in write_insns])
 
         if write_insns:
+            # Determine the shape of the quantity
+            shape = knl.temporary_variables[quantity].shape
+
             arrname = quantity + '_buffered_arr'
             knl = add_temporary_with_vector_view(knl,
                                                  arrname,
                                                  dtype=np.float64,
-                                                 shape=(vec_size,),
-                                                 dim_tags="c",
+                                                 shape=shape + (vec_size,),
+                                                 dim_tags=",".join("c" for i in range(len(shape) + 1)),
                                                  base_storage=quantity + '_base_storage',
                                                  scope=lp.temp_var_scope.PRIVATE,
                                                  )
 
-            replacemap_arr[quantity] = prim.Subscript(prim.Variable(arrname), (prim.Variable('rotate_index'),))
-            replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(arrname)), (0, prim.Variable(new_iname),))
+            def get_quantity_subscripts(e, zero=False):
+                if isinstance(e, prim.Subscript):
+                    index = e.index
+                    if isinstance(index, tuple):
+                        return index
+                    else:
+                        return (index,)
+                else:
+                    if zero:
+                        return (0,)
+                    else:
+                        return ()
+
+            for expr in quantities[quantity]:
+                replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(arrname)), get_quantity_subscripts(expr, zero=True) + (prim.Variable(new_iname),))
 
             for insn in write_insns:
                 if isinstance(insn, lp.Assignment):
-                    new_insns.append(insn.copy(assignee=replacemap_arr[get_pymbolic_basename(insn.assignee)],
+                    assignee = prim.Subscript(prim.Variable(arrname), get_quantity_subscripts(insn.assignee) + (prim.Variable('rotate_index'),))
+                    new_insns.append(insn.copy(assignee=assignee,
                                                depends_on_is_final=True,
                                                )
                                  )
                 elif isinstance(insn, lp.CInstruction):
                     # Rip apart the code and change the assignee
                     assignee, expression = insn.code.split("=")
-                    assignee = assignee.strip()
-                    assert assignee in replacemap_arr
 
-                    code = "{} ={}".format(str(replacemap_arr[assignee]), expression)
+                    # TODO This is a bit whacky: It only works for scalar assignees
+                    # OTOH this code is on its way out anyway, because of CInstruction
+                    assignee = prim.Subscript(prim.Variable(arrname), (prim.Variable('rotate_index'),))
+
+                    code = "{} ={}".format(str(assignee), expression)
                     new_insns.append(insn.copy(code=code,
                                                depends_on_is_final=True,
                                                ))
                 else:
                     raise NotImplementedError
-        else:
+        elif quantity in knl.temporary_variables:
             # Add a vector view to this quantity
             knl = add_vector_view(knl, quantity)
             replacemap_vec[expr] = prim.Subscript(prim.Variable(get_vector_view_name(quantity)),
diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py
index cd64a827..e0848537 100644
--- a/python/dune/perftool/pdelab/localoperator.py
+++ b/python/dune/perftool/pdelab/localoperator.py
@@ -26,6 +26,8 @@ from dune.perftool.ufl.modified_terminals import Restriction
 from pymbolic.primitives import Variable
 from pytools import Record
 
+import loopy as lp
+
 
 def name_form(formdata, data):
     # Check wether the formdata has a name in UFL
@@ -508,6 +510,19 @@ def generate_kernel(integrals):
     from dune.perftool.loopy import heuristic_duplication
     kernel = heuristic_duplication(kernel)
 
+    # Maybe apply vectorization strategies
+    if get_option("vectorize"):
+        if get_option("sumfact"):
+            # Vectorization of the quadrature loop
+            insns = [i.id for i in lp.find_instructions(kernel, lp.match.Tagged("quadvec"))]
+            from dune.perftool.sumfact.quadrature import quadrature_inames
+            inames = quadrature_inames()
+
+            from dune.perftool.loopy.transformations.collect_rotate import collect_vector_data_rotate
+            kernel = collect_vector_data_rotate(kernel, insns, inames)
+        else:
+            raise NotImplementedError("Only vectorizing sumfactoized code right now!")
+
     # Now add the preambles to the kernel
     preambles = [(i, p) for i, p in enumerate(retrieve_cache_items("preamble"))]
     kernel = kernel.copy(preambles=preambles)
diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py
index 43b0ff43..77edc9e0 100644
--- a/python/dune/perftool/sumfact/basis.py
+++ b/python/dune/perftool/sumfact/basis.py
@@ -41,32 +41,13 @@ def name_sumfact_base_buffer():
 
 @cached
 def sumfact_evaluate_coefficient_gradient(element, name, restriction, component):
-    # First we determine the rank of the tensor we are talking about
+    # Get a temporary for the gradient
     from ufl.functionview import select_subelement
     sub_element = select_subelement(element, component)
     rank = len(sub_element.value_shape()) + 1
-
-    # We do then set some variables accordingly
     shape = sub_element.value_shape() + (element.cell().geometric_dimension(),)
     shape_impl = ('arr',) * rank
-
-    from dune.perftool.pdelab.geometry import dimension_iname
-    idims = tuple(dimension_iname(count=i) for i in range(rank))
-    leaf_element = sub_element
-    from ufl import VectorElement, TensorElement
-    if isinstance(sub_element, (VectorElement, TensorElement)):
-        leaf_element = sub_element.sub_elements()[0]
-
-    # and proceed to call the necessary generator functions
     temporary_variable(name, shape=shape, shape_impl=shape_impl)
-    from dune.perftool.pdelab.spaces import name_lfs
-    lfs = name_lfs(element, restriction, component)
-    from dune.perftool.pdelab.basis import pymbolic_reference_gradient
-    basis = pymbolic_reference_gradient(leaf_element, restriction, 0, context='trialgrad')
-    from dune.perftool.tools import get_pymbolic_indices
-    index, _ = get_pymbolic_indices(basis)
-    if isinstance(sub_element, (VectorElement, TensorElement)):
-        lfs = lfs_child(lfs, idims[:-1], shape=shape_as_pymbolic(shape[:-1]), symmetry=element.symmetry())
 
     # Calculate values with sumfactorization
     theta = name_theta()
@@ -111,7 +92,7 @@ def sumfact_evaluate_coefficient_gradient(element, name, restriction, component)
         expression = Subscript(Variable(buf), tuple(Variable(i) for i in quadrature_inames()))
         instruction(assignee=assignee,
                     expression=expression,
-                    forced_iname_deps=frozenset(get_backend("quad_inames")()).union(frozenset(idims)),
+                    forced_iname_deps=frozenset(get_backend("quad_inames")()),
                     forced_iname_deps_is_final=True,
                     )
 
@@ -204,9 +185,6 @@ def pymbolic_basis(element, restriction, number):
 @backend(interface="evaluate_grad")
 @cached
 def evaluate_reference_gradient(element, name, restriction):
-    # from dune.perftool.pdelab.basis import name_leaf_lfs
-    # lfs = name_leaf_lfs(element, restriction)
-    # from dune.perftool.pdelab.spaces import name_lfs_bound
     from dune.perftool.pdelab.geometry import name_dimension
     temporary_variable(
         name,
diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py
index f163217d..36240228 100644
--- a/python/dune/perftool/sumfact/sumfact.py
+++ b/python/dune/perftool/sumfact/sumfact.py
@@ -157,6 +157,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
                                   expression=expression,
                                   forced_iname_deps=frozenset(quadrature_inames() + visitor.inames),
                                   forced_iname_deps_is_final=True,
+                                  tags=frozenset({"quadvec"}),
                                   )
         contribution_ids.append(contrib_dep)
 
@@ -205,11 +206,6 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
         # Mark the transformation that moves the quadrature loop inside the trialfunction loops for application
         transform(nest_quadrature_loops, visitor.inames)
 
-    # Maybe try to vectorize!
-    if get_option("vectorize"):
-        from dune.perftool.loopy.transformations.collect_rotate import collect_vector_data_rotate
-        transform(collect_vector_data_rotate, contribution_ids, quadrature_inames())
-
 
 def sum_factorization_kernel(a_matrices, buf, insn_dep=frozenset({}), additional_inames=frozenset({})):
     """
diff --git a/test/sumfact/poisson/poisson_order1.mini b/test/sumfact/poisson/poisson_order1.mini
index 0c4919fe..75918406 100644
--- a/test/sumfact/poisson/poisson_order1.mini
+++ b/test/sumfact/poisson/poisson_order1.mini
@@ -1,5 +1,8 @@
 __name = sumfact_poisson_order1_{__exec_suffix}
-__exec_suffix = numdiff, symdiff | expand num
+__exec_suffix = {diff_suffix}_{vec_suffix}
+
+diff_suffix = numdiff, symdiff | expand num
+vec_suffix = vec, nonvec | expand vec
 
 cells = 8 8
 extension = 1. 1.
@@ -14,3 +17,4 @@ numerical_jacobian = 1, 0 | expand num
 exact_solution_expression = g
 compare_l2errorsquared = 1e-4
 sumfact = 1
+vectorize = 1, 0 | expand vec
diff --git a/test/sumfact/poisson/poisson_order2.mini b/test/sumfact/poisson/poisson_order2.mini
index 88e4f3b8..8830f173 100644
--- a/test/sumfact/poisson/poisson_order2.mini
+++ b/test/sumfact/poisson/poisson_order2.mini
@@ -1,5 +1,8 @@
 __name = sumfact_poisson_order2_{__exec_suffix}
-__exec_suffix = numdiff, symdiff | expand num
+__exec_suffix = {diff_suffix}_{vec_suffix}
+
+diff_suffix = numdiff, symdiff | expand num
+vec_suffix = vec, nonvec | expand vec
 
 cells = 8 8
 extension = 1. 1.
@@ -14,3 +17,4 @@ numerical_jacobian = 1, 0 | expand num
 exact_solution_expression = g
 compare_l2errorsquared = 1e-8
 sumfact = 1
+vectorize = 1, 0 | expand vec
-- 
GitLab