From c29aa746719862dbc6bc433bc865aa161b77626c Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Mon, 5 Dec 2016 17:44:09 +0100
Subject: [PATCH] Fix gradvec strategy

---
 python/dune/perftool/sumfact/basis.py   | 40 +++++++++++++++++--------
 python/dune/perftool/sumfact/sumfact.py | 18 +++++------
 2 files changed, 36 insertions(+), 22 deletions(-)

diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py
index 7221e89d..62077833 100644
--- a/python/dune/perftool/sumfact/basis.py
+++ b/python/dune/perftool/sumfact/basis.py
@@ -27,6 +27,7 @@ from dune.perftool.sumfact.quadrature import quadrature_inames
 from dune.perftool.loopy.buffer import initialize_buffer
 from dune.perftool.pdelab.driver import FEM_name_mangling
 from dune.perftool.pdelab.restriction import restricted_name
+from dune.perftool.tools import maybe_wrap_subscript
 
 from pytools import product
 
@@ -67,6 +68,8 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor):
     formdata = get_global_context_value('formdata')
     dim = formdata.geometric_dimension
     buffers = []
+    insn_dep = None
+    ret = False
     for i in range(dim):
         a_matrices = [theta_matrix] * dim
         a_matrices[i] = dtheta_matrix
@@ -86,6 +89,8 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor):
                                   ).get_temporary(shape=shape,
                                                   name=input,
                                                   )
+        if insn_dep is None:
+            insn_dep = frozenset({Writes(input)})
 
         # Setup the input!
         setup_theta(input, element, restriction, component, index)
@@ -94,22 +99,31 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor):
         # evaluation of the gradients of basis functions at quadrature
         # points (stage 1)
         if index:
-            var, _ = sum_factorization_kernel(a_matrices,
-                                              buffer,
-                                              1,
-                                              preferred_position=i,
-                                              insn_dep=frozenset({Writes(input)}),
-                                              output_name=name,
-                                              )
+            assert len(visitor.indices) == 1
+            var, insn_dep = sum_factorization_kernel(a_matrices,
+                                                     buffer,
+                                                     1,
+                                                     preferred_position=i,
+                                                     insn_dep=insn_dep,
+                                                     )
+            ret = True
         else:
-            var, _ = sum_factorization_kernel(a_matrices,
-                                              buffer,
-                                              1,
-                                              preferred_position=i,
-                                              insn_dep=frozenset({Writes(input)}),
-                                              )
+            var, insn_dep = sum_factorization_kernel(a_matrices,
+                                                     buffer,
+                                                     1,
+                                                     preferred_position=i,
+                                                     insn_dep=insn_dep,
+                                                     )
             buffers.append(var)
 
+    # Check whether we want to return early with something that has the indexing
+    # already handled! This happens with vectorization when the index coincides
+    # with the position in the vector register.
+    if ret:
+        indices = visitor.indices
+        visitor.indices = None
+        return maybe_wrap_subscript(var, tuple(prim.Variable(i) for i in quadrature_inames()) + indices)
+
     # TODO this should be quite conditional!!!
     for i, buf in enumerate(buffers):
         # Write solution from sumfactorization to gradient variable
diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py
index 059488f9..5787913e 100644
--- a/python/dune/perftool/sumfact/sumfact.py
+++ b/python/dune/perftool/sumfact/sumfact.py
@@ -11,6 +11,7 @@ from dune.perftool.generation import (backend,
                                       built_instruction,
                                       domain,
                                       function_mangler,
+                                      generator_factory,
                                       get_counter,
                                       get_global_context_value,
                                       globalarg,
@@ -120,10 +121,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
     else:
         buffers.append(name_test_function_contribution(accterm.argument))
 
-    # Collect a list of instruction IDs of the contributions
-    contribution_ids = []
-
-    # TODO covers only 2D
+    insn_dep = None
     for i, buf in enumerate(buffers):
         # Get the a matrices needed for this accumulation term
         rows = basis_functions_per_direction()
@@ -188,14 +186,16 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
                                   tags=frozenset({"quadvec"}),
                                   depends_on=frozenset({deps})
                                   )
-        contribution_ids.append(contrib_dep)
+
+        if insn_dep is None:
+            insn_dep = frozenset({contrib_dep})
 
         # Add a sum factorization kernel that implements the multiplication
         # with the test function (stage 3)
         result, insn_dep = sum_factorization_kernel(a_matrices,
                                                     buffer,
                                                     3,
-                                                    insn_dep=frozenset({contrib_dep}),
+                                                    insn_dep=insn_dep,
                                                     additional_inames=frozenset(visitor.inames),
                                                     preferred_position=pref_pos,
                                                     )
@@ -223,7 +223,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
         expr = Call(PDELabAccumulationFunction(accum, rank),
                     (ansatz_lfs.get_args() +
                      test_lfs.get_args() +
-                     (Subscript(result, tuple(Variable(i) for i in inames)),)
+                     (Subscript(result, tuple(Variable(i) for i in inames) + index),)
                      )
                     )
         instruction(assignees=(),
@@ -237,7 +237,8 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
         transform(nest_quadrature_loops, visitor.inames)
 
 
-def sum_factorization_kernel(a_matrices, buf, stage, insn_dep=frozenset({}), additional_inames=frozenset({}), preferred_position=None, output_name=None):
+@generator_factory(item_tags=("sumfactkernel",), context_tags=("kernel",), cache_key_generator=lambda a, b, s, **kw: (a,b,s))
+def sum_factorization_kernel(a_matrices, buf, stage, insn_dep=frozenset({}), additional_inames=frozenset({}), preferred_position=None):
     """
     Calculate a sum factorization matrix product.
 
@@ -340,7 +341,6 @@ def sum_factorization_kernel(a_matrices, buf, stage, insn_dep=frozenset({}), add
     out = get_buffer_temporary(buf,
                                shape=out_shape,
                                dim_tags=dim_tags,
-                               name=output_name,
                                )
     silenced_warning('read_no_write({})'.format(out))
 
-- 
GitLab