From ebe7b92a7e8cde9dc39f9178e69cace3c945a28a Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Thu, 20 Sep 2018 15:43:39 +0200
Subject: [PATCH] Add changes for nonlinear sumfact CG system

Needed in dune-perftool-playground for Navier stokes
---
 .../dune/perftool/loopy/transformations/vectorize_quad.py | 8 +++++++-
 python/dune/perftool/sumfact/symbolic.py                  | 4 ++--
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/python/dune/perftool/loopy/transformations/vectorize_quad.py b/python/dune/perftool/loopy/transformations/vectorize_quad.py
index f82f4829..f1b3ce12 100644
--- a/python/dune/perftool/loopy/transformations/vectorize_quad.py
+++ b/python/dune/perftool/loopy/transformations/vectorize_quad.py
@@ -197,6 +197,11 @@ def _vectorize_quadrature_loop(knl, inames, suffix):
                 # 1. Rotating the input data
                 knl = add_vector_view(knl, quantity)
                 if horizontal > 1:
+                    # Pitfall: In the case of generating jacobians, the input needs to be rotated exactly once.
+                    predicates = frozenset()
+                    if common_inames:
+                        predicates = frozenset({prim.Comparison(prim.Sum(tuple(prim.Variable(i) for i in common_inames)), "==", 0)})
+
                     new_insns.append(lp.CallInstruction((),  # assignees
                                                         prim.Call(TransposeReg(vertical=vertical, horizontal=horizontal),
                                                                   tuple(prim.Subscript(prim.Variable(get_vector_view_name(quantity)),
@@ -207,6 +212,7 @@ def _vectorize_quadrature_loop(knl, inames, suffix):
                                                         within_inames_is_final=True,
                                                         id="{}_rotate{}".format(quantity, suffix),
                                                         tags=frozenset({"sumfact_stage2"}),
+                                                        predicates=predicates,
                                                         ))
 
                 # Add substitution rules
@@ -263,7 +269,7 @@ def _vectorize_quadrature_loop(knl, inames, suffix):
                                                       (vector_indices.get(horizontal) + last_index, prim.Variable(vec_iname)),
                                                       ),
                                        substitute(insn.expression, replacemap),
-                                       depends_on=frozenset({"continue_stmt{}".format(suffix)}),
+                                       depends_on=frozenset({"continue_stmt{}".format(suffix), lp.match.Tagged("sumfact_stage1")}),
                                        depends_on_is_final=True,
                                        within_inames=common_inames.union(frozenset({outer_iname, vec_iname})),
                                        within_inames_is_final=True,
diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py
index 85601c1e..435975e4 100644
--- a/python/dune/perftool/sumfact/symbolic.py
+++ b/python/dune/perftool/sumfact/symbolic.py
@@ -155,7 +155,7 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase):
         outputs = set(self.interfaces)
 
         trial_element, = set(o.trial_element for o in self.interfaces)
-        trial_element_index, = set(o.trial_element_index for o in self.interfaces)
+        trial_element_index = set(o.trial_element_index for o in self.interfaces).pop()
         from dune.perftool.sumfact.accumulation import accum_iname
         element = get_leaf(trial_element, trial_element_index) if trial_element is not None else None
         inames = tuple(accum_iname(element, mat.rows, i)
@@ -368,7 +368,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
     @property
     def parallel_key(self):
         """ A key that identifies parallellizable kernels. """
-        return tuple(m.basis_size for m in self.permuted_matrix_sequence) + (self.stage, self.buffer)
+        return tuple(m.basis_size for m in self.permuted_matrix_sequence) + (self.stage, self.buffer, self.interface.within_inames)
 
     @property
     def cache_key(self):
-- 
GitLab