From 85e2dd2dfbf3305d8a1f16541436cb75b68d60d9 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Fri, 15 Sep 2017 16:49:20 +0200
Subject: [PATCH] [bugfix] Correct quadrature inames at sumfact kernel output

When doing jacobians of nonlinear systems, the inames used for
a given sum factorization kernel output depend on the accumulation
context.
---
 python/dune/perftool/sumfact/__init__.py     | 15 ++----
 python/dune/perftool/sumfact/accumulation.py |  2 +-
 python/dune/perftool/sumfact/basis.py        | 14 +++---
 python/dune/perftool/sumfact/geometry.py     | 18 ++++----
 python/dune/perftool/sumfact/symbolic.py     | 48 ++++++++++++++------
 5 files changed, 55 insertions(+), 42 deletions(-)

diff --git a/python/dune/perftool/sumfact/__init__.py b/python/dune/perftool/sumfact/__init__.py
index df246dee..a4fce883 100644
--- a/python/dune/perftool/sumfact/__init__.py
+++ b/python/dune/perftool/sumfact/__init__.py
@@ -46,23 +46,19 @@ class SumFactInterface(PDELabInterface):
         return ret
 
     def pymbolic_trialfunction_gradient(self, element, restriction, index):
-        ret, indices = pymbolic_coefficient_gradient(element, restriction, index, name_coefficientcontainer, self.visitor.indices)
-        self.visitor.indices = indices
+        ret = pymbolic_coefficient_gradient(element, restriction, index, name_coefficientcontainer, self.visitor)
         return ret
 
     def pymbolic_trialfunction(self, element, restriction, index):
-        ret, indices = pymbolic_coefficient(element, restriction, index, name_coefficientcontainer, self.visitor.indices)
-        self.visitor.indices = indices
+        ret = pymbolic_coefficient(element, restriction, index, name_coefficientcontainer, self.visitor)
         return ret
 
     def pymbolic_apply_function_gradient(self, element, restriction, index):
-        ret, indices = pymbolic_coefficient_gradient(element, restriction, index, name_applycontainer, self.visitor.indices)
-        self.visitor.indices = indices
+        ret = pymbolic_coefficient_gradient(element, restriction, index, name_applycontainer, self.visitor)
         return ret
 
     def pymbolic_apply_function(self, element, restriction, index):
-        ret, indices = pymbolic_coefficient(element, restriction, index, name_applycontainer, self.visitor.indices)
-        self.visitor.indices = indices
+        ret = pymbolic_coefficient(element, restriction, index, name_applycontainer, self.visitor)
         return ret
 
     def quadrature_inames(self):
@@ -73,8 +69,7 @@ class SumFactInterface(PDELabInterface):
 
     def pymbolic_spatial_coordinate(self):
         import dune.perftool.sumfact.geometry
-        ret, indices = get_backend(interface="spatial_coordinate", selector=option_switch("diagonal_transformation_matrix"))(self.visitor.indices, self.visitor.do_predicates, self.visitor)
-        self.visitor.indices = indices
+        ret = get_backend(interface="spatial_coordinate", selector=option_switch("diagonal_transformation_matrix"))(self.visitor.do_predicates, self.visitor)
         return ret
 
     def pymbolic_unit_outer_normal(self):
diff --git a/python/dune/perftool/sumfact/accumulation.py b/python/dune/perftool/sumfact/accumulation.py
index dac38218..f8ef3940 100644
--- a/python/dune/perftool/sumfact/accumulation.py
+++ b/python/dune/perftool/sumfact/accumulation.py
@@ -320,7 +320,7 @@ def generate_accumulation_instruction(expr, visitor):
     # Issue an instruction in the quadrature loop that fills the buffer
     # with the evaluation of the contribution at all quadrature points
     assignee = prim.Subscript(lp.TaggedVariable(temp, vsf.tag),
-                              vsf.quadrature_index(sf))
+                              vsf.quadrature_index(sf, visitor))
     contrib_dep = instruction(assignee=assignee,
                               expression=expr,
                               forced_iname_deps=frozenset(quadrature_inames(trial_leaf_element) + jacobian_inames),
diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py
index d01c8ef8..c1ca1a84 100644
--- a/python/dune/perftool/sumfact/basis.py
+++ b/python/dune/perftool/sumfact/basis.py
@@ -111,10 +111,9 @@ def _basis_functions_per_direction(element):
     return basis_size
 
 
-@kernel_cached
-def pymbolic_coefficient_gradient(element, restriction, index, coeff_func, visitor_indices):
+def pymbolic_coefficient_gradient(element, restriction, index, coeff_func, visitor):
     sub_element = element
-    grad_index = visitor_indices[0]
+    grad_index = visitor.indices[0]
     if isinstance(element, MixedElement):
         sub_element = element.extract_component(index)[1]
 
@@ -154,11 +153,11 @@ def pymbolic_coefficient_gradient(element, restriction, index, coeff_func, visit
     from dune.perftool.sumfact.realization import realize_sum_factorization_kernel
     var, insn_dep = realize_sum_factorization_kernel(vsf)
 
-    return prim.Subscript(var, vsf.quadrature_index(sf)), None
+    visitor.indices = None
+    return prim.Subscript(var, vsf.quadrature_index(sf, visitor))
 
 
-@kernel_cached
-def pymbolic_coefficient(element, restriction, index, coeff_func, visitor_indices):
+def pymbolic_coefficient(element, restriction, index, coeff_func, visitor):
     sub_element = element
     if isinstance(element, MixedElement):
         sub_element = element.extract_component(index)[1]
@@ -197,7 +196,8 @@ def pymbolic_coefficient(element, restriction, index, coeff_func, visitor_indice
     from dune.perftool.sumfact.realization import realize_sum_factorization_kernel
     var, _ = realize_sum_factorization_kernel(vsf)
 
-    return prim.Subscript(var, vsf.quadrature_index(sf)), None
+    visitor.indices = None
+    return prim.Subscript(var, vsf.quadrature_index(sf, visitor))
 
 
 @iname
diff --git a/python/dune/perftool/sumfact/geometry.py b/python/dune/perftool/sumfact/geometry.py
index e7d6c127..9f0fbcbd 100644
--- a/python/dune/perftool/sumfact/geometry.py
+++ b/python/dune/perftool/sumfact/geometry.py
@@ -67,10 +67,9 @@ class GeoCornersInput(SumfactKernelInputBase, ImmutableRecord):
                     )
 
 
-@kernel_cached
 @backend(interface="spatial_coordinate", name="default")
-def pymbolic_spatial_coordinate_multilinear(visitor_indices, do_predicates, visitor):
-    assert len(visitor_indices) == 1
+def pymbolic_spatial_coordinate_multilinear(do_predicates, visitor):
+    assert len(visitor.indices) == 1
 
     # Construct the matrix sequence for the evaluation of the global coordinate.
     # We need to manually construct this one, because on facets, we want to use the
@@ -80,7 +79,7 @@ def pymbolic_spatial_coordinate_multilinear(visitor_indices, do_predicates, visi
     from dune.perftool.sumfact.tabulation import quadrature_points_per_direction, BasisTabulationMatrix
     quadrature_size = quadrature_points_per_direction()
     matrix_sequence = (BasisTabulationMatrix(quadrature_size=quadrature_size, basis_size=2),) * local_dimension()
-    inp = GeoCornersInput(visitor_indices[0])
+    inp = GeoCornersInput(visitor.indices[0])
 
     from dune.perftool.sumfact.symbolic import SumfactKernel
     sf = SumfactKernel(matrix_sequence=matrix_sequence,
@@ -94,7 +93,8 @@ def pymbolic_spatial_coordinate_multilinear(visitor_indices, do_predicates, visi
     from dune.perftool.sumfact.realization import realize_sum_factorization_kernel
     var, _ = realize_sum_factorization_kernel(vsf)
 
-    return prim.Subscript(var, vsf.quadrature_index(sf)), None
+    visitor.indices = None
+    return prim.Subscript(var, vsf.quadrature_index(sf, visitor)), None
 
 
 @preamble
@@ -126,11 +126,10 @@ def name_meshwidth():
     return name
 
 
-@kernel_cached
 @backend(interface="spatial_coordinate", name="diagonal_transformation_matrix")
-def pymbolic_spatial_coordinate_axiparallel(visitor_indices, do_predicates, visitor):
-    assert len(visitor_indices) == 1
-    index, = visitor_indices
+def pymbolic_spatial_coordinate_axiparallel(do_predicates, visitor):
+    assert len(visitor.indices) == 1
+    index, = visitor.indices
 
     # Urgh: *SOMEHOW* construct a face direction
     from dune.perftool.pdelab.restriction import Restriction
@@ -159,6 +158,7 @@ def pymbolic_spatial_coordinate_axiparallel(visitor_indices, do_predicates, visi
         from dune.perftool.sumfact.quadrature import pymbolic_quadrature_position
         x = pymbolic_quadrature_position(iindex, visitor)
 
+    visitor.indices = None
     return prim.Subscript(prim.Variable(lowcorner), (index,)) + x * prim.Subscript(prim.Variable(meshwidth), (index,)), None
 
 
diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py
index a5c4eb06..88446417 100644
--- a/python/dune/perftool/sumfact/symbolic.py
+++ b/python/dune/perftool/sumfact/symbolic.py
@@ -216,10 +216,16 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
         """
         return tuple(mat.quadrature_size for mat in self.matrix_sequence)
 
-    def quadrature_index(self, _):
-        element = self.trial_element
-        if element is not None and isinstance(element, MixedElement):
-            element = element.extract_component(self.trial_element_index)[1]
+    def quadrature_index(self, sf, visitor):
+        if visitor.current_info[1] is None:
+            element = None
+            element_index = 0
+        else:
+            element = visitor.current_info[1].element
+            element_index = visitor.current_info[1].element_index
+            if isinstance(element, MixedElement):
+                element = element.extract_component(element_index)[1]
+
         quad_inames = quadrature_inames(element)
         if len(self.matrix_sequence) == local_dimension():
             return tuple(prim.Variable(i) for i in quad_inames)
@@ -474,10 +480,16 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
                 return i
         return 0
 
-    def _quadrature_index(self, sf):
-        element = self.trial_element
-        if element is not None and isinstance(element, MixedElement):
-            element = element.extract_component(self.trial_element_index)[1]
+    def _quadrature_index(self, sf, visitor):
+        if visitor.current_info[1] is None:
+            element = None
+            element_index = 0
+        else:
+            element = visitor.current_info[1].element
+            element_index = visitor.current_info[1].element_index
+            if isinstance(element, MixedElement):
+                element = element.extract_component(element_index)[1]
+
         quad_inames = quadrature_inames(element)
         index = []
 
@@ -506,10 +518,16 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
 
         return tuple(index)
 
-    def vec_index(self, sf):
-        element = self.trial_element
-        if element is not None and isinstance(element, MixedElement):
-            element = element.extract_component(self.trial_element_index)[1]
+    def vec_index(self, sf, visitor):
+        if visitor.current_info[1] is None:
+            element = None
+            element_index = 0
+        else:
+            element = visitor.current_info[1].element
+            element_index = visitor.current_info[1].element_index
+            if isinstance(element, MixedElement):
+                element = element.extract_component(element_index)[1]
+
         quad_inames = quadrature_inames(element)
         sliced = 0
         if len(sf.matrix_sequence) == local_dimension():
@@ -530,13 +548,13 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
     def quadrature_shape(self):
         return tuple(mat.quadrature_size for mat in self.matrix_sequence) + (self.vector_width,)
 
-    def quadrature_index(self, sf, direct_index=None):
-        quad = self._quadrature_index(sf)
+    def quadrature_index(self, sf, visitor, direct_index=None):
+        quad = self._quadrature_index(sf, visitor)
         if direct_index is not None:
             assert isinstance(direct_index, tuple)
             return quad + direct_index
         else:
-            return quad + (self.vec_index(sf),)
+            return quad + (self.vec_index(sf, visitor),)
 
     @property
     def quadrature_dimtags(self):
-- 
GitLab