From 12cac85620be8c1e67b4426ea0df7dae801213a4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ren=C3=A9=20He=C3=9F?= <rene.hess@iwr.uni-heidelberg.de>
Date: Fri, 21 Dec 2018 10:49:24 +0100
Subject: [PATCH] Fix bug in horizontal_index in loop_splitting case

---
 python/dune/codegen/sumfact/symbolic.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/python/dune/codegen/sumfact/symbolic.py b/python/dune/codegen/sumfact/symbolic.py
index 45cd7548..25b3a795 100644
--- a/python/dune/codegen/sumfact/symbolic.py
+++ b/python/dune/codegen/sumfact/symbolic.py
@@ -994,7 +994,16 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
 
     def horizontal_index(self, sf):
         for i, k in enumerate(self.kernels):
-            if sf.inout_key == k.inout_key:
+            # We need to identify to which part of the vectorized kernel sf
+            # corresponds. Since splitting might change the cost_permutation we
+            # exclude it in the comparison below. We also make sure to check
+            # that derivatives are the same.
+            from copy import deepcopy
+            sf_interface = deepcopy(sf.interface)
+            sf_interface._cost_permutation=None
+            k_interface = deepcopy(k.interface)
+            k_interface._cost_permutation=None
+            if repr(sf_interface) == repr(k_interface):
                 if tuple(mat.derivative for mat in sf.matrix_sequence_quadrature_permuted) == tuple(mat.derivative for mat in k.matrix_sequence_quadrature_permuted):
                     return i
 
@@ -1050,7 +1059,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
 
         quad_inames = quadrature_inames(element)
         sliced = 0
-        if len(sf.matrix_sequence_quadrature_permuted) == local_dimension():
+        if len(self.matrix_sequence_quadrature_permuted) == local_dimension():
             for d in range(local_dimension()):
                 if self.matrix_sequence_quadrature_permuted[d].slice_size:
                     sliced = prim.Variable(quad_inames[d])
-- 
GitLab