Skip to content
Snippets Groups Projects
Commit 24391f12 authored by René Heß's avatar René Heß
Browse files

Choose vectorization strategy based on cost permuted matrix sequence

parent 8b1b7acd
No related branches found
No related tags found
No related merge requests found
......@@ -376,7 +376,10 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
@property
def parallel_key(self):
""" A key that identifies parallellizable kernels. """
return tuple(m.basis_size for m in self.permuted_matrix_sequence) + (self.stage, self.buffer, self.interface.within_inames)
# TODO: For now we do not vectorize SumfactKernels with different
# quadrature_permutation. This should be handled like upper/lower
# vectorization
return self.quadrature_permutation + tuple(m.basis_size for m in self.permuted_matrix_sequence) + (self.stage, self.buffer, self.interface.within_inames)
@property
def cache_key(self):
......@@ -576,8 +579,10 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
def operations(self):
""" The total number of floating point operations for the kernel
to be carried out """
from dune.codegen.sumfact.permutation import flop_cost
return flop_cost(self.permuted_matrix_sequence)
from dune.codegen.sumfact.permutation import flop_cost, sumfact_permutation_strategy, permute_forward
perm = sumfact_permutation_strategy(self)
permuted_matrix_sequence_cost = permute_forward(self.matrix_sequence, perm)
return flop_cost(permuted_matrix_sequence_cost)
# Extract the argument list and store it on the class. This needs to be done
......@@ -881,5 +886,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
def operations(self):
""" The total number of floating point operations for the kernel
to be carried out """
from dune.codegen.sumfact.permutation import flop_cost
return flop_cost(self.matrix_sequence)
from dune.codegen.sumfact.permutation import flop_cost, sumfact_permutation_strategy, permute_forward
perm = sumfact_permutation_strategy(self)
permuted_matrix_sequence_cost = permute_forward(self.matrix_sequence, perm)
return flop_cost(permuted_matrix_sequence_cost)
......@@ -15,6 +15,7 @@ extension = vtu
[formcompiler]
compare_l2errorsquared = 1e-4, 5e-6 | expand deg
debug_interpolate_input = 1
[formcompiler.r]
numerical_jacobian = 1, 0 | expand num
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment