Skip to content
Snippets Groups Projects
Commit 12cac856 authored by René Heß's avatar René Heß
Browse files

Fix bug in horizontal_index in loop_splitting case

parent c6513ccf
No related branches found
No related tags found
No related merge requests found
...@@ -994,7 +994,16 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) ...@@ -994,7 +994,16 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
def horizontal_index(self, sf): def horizontal_index(self, sf):
for i, k in enumerate(self.kernels): for i, k in enumerate(self.kernels):
if sf.inout_key == k.inout_key: # We need to identify to which part of the vectorized kernel sf
# corresponds. Since splitting might change the cost_permutation we
# exclude it in the comparison below. We also make sure to check
# that derivatives are the same.
from copy import deepcopy
sf_interface = deepcopy(sf.interface)
sf_interface._cost_permutation=None
k_interface = deepcopy(k.interface)
k_interface._cost_permutation=None
if repr(sf_interface) == repr(k_interface):
if tuple(mat.derivative for mat in sf.matrix_sequence_quadrature_permuted) == tuple(mat.derivative for mat in k.matrix_sequence_quadrature_permuted): if tuple(mat.derivative for mat in sf.matrix_sequence_quadrature_permuted) == tuple(mat.derivative for mat in k.matrix_sequence_quadrature_permuted):
return i return i
...@@ -1050,7 +1059,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) ...@@ -1050,7 +1059,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
quad_inames = quadrature_inames(element) quad_inames = quadrature_inames(element)
sliced = 0 sliced = 0
if len(sf.matrix_sequence_quadrature_permuted) == local_dimension(): if len(self.matrix_sequence_quadrature_permuted) == local_dimension():
for d in range(local_dimension()): for d in range(local_dimension()):
if self.matrix_sequence_quadrature_permuted[d].slice_size: if self.matrix_sequence_quadrature_permuted[d].slice_size:
sliced = prim.Variable(quad_inames[d]) sliced = prim.Variable(quad_inames[d])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment