diff --git a/python/dune/codegen/sumfact/permutation.py b/python/dune/codegen/sumfact/permutation.py index 6e9fdaaad5dc9021d0ec076df7cc743806eab512..916f7773592191a7dcc0731e56b93524daf4528e 100644 --- a/python/dune/codegen/sumfact/permutation.py +++ b/python/dune/codegen/sumfact/permutation.py @@ -3,6 +3,7 @@ import itertools from dune.codegen.options import get_option +from dune.codegen.sumfact.tabulation import quadrature_points_per_direction from dune.codegen.ufl.modified_terminals import Restriction diff --git a/python/dune/codegen/sumfact/symbolic.py b/python/dune/codegen/sumfact/symbolic.py index 8d81ea0462541cff4e952f37bd02d34c2e36454a..8fdd1dd8d23ddca5745be5acb4e02837478d4c9d 100644 --- a/python/dune/codegen/sumfact/symbolic.py +++ b/python/dune/codegen/sumfact/symbolic.py @@ -14,7 +14,7 @@ from dune.codegen.sumfact.permutation import (flop_cost, sumfact_cost_permutation_strategy, sumfact_quadrature_permutation_strategy, ) -from dune.codegen.sumfact.tabulation import BasisTabulationMatrixBase, BasisTabulationMatrixArray +from dune.codegen.sumfact.tabulation import BasisTabulationMatrixBase, BasisTabulationMatrixArray, quadrature_points_per_direction from dune.codegen.loopy.target import dtype_floatingpoint, type_floatingpoint from dune.codegen.loopy.vcl import ExplicitVCLCast, VCLLowerUpperLoad from dune.codegen.tools import get_leaf, maybe_wrap_subscript, remove_duplicates @@ -564,6 +564,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): # Precompute and cache a number of keys self._cached_cache_key = None + self._cached_flop_cost = {} # # The methods/fields needed to get a well-formed pymbolic node @@ -824,7 +825,10 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): def operations(self): """ The total number of floating point operations for the kernel to be carried out """ - return flop_cost(self.matrix_sequence_cost_permuted) + qp = quadrature_points_per_direction() + if qp not in self._cached_flop_cost: + self._cached_flop_cost[qp] = flop_cost(self.matrix_sequence_cost_permuted) + return self._cached_flop_cost[qp] # Extract the argument list and store it on the class. This needs to be done @@ -873,6 +877,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) # Precompute and cache a number of keys self._cached_cache_key = None + self._cached_flop_cost = {} def __getinitargs__(self): return (self.kernels, self.horizontal_width, self.vertical_width, self.buffer, self.insn_dep) @@ -1135,4 +1140,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) def operations(self): """ The total number of floating point operations for the kernel to be carried out """ - return flop_cost(self.matrix_sequence_cost_permuted) + qp = quadrature_points_per_direction() + if qp not in self._cached_flop_cost: + self._cached_flop_cost[qp] = flop_cost(self.matrix_sequence_cost_permuted) + return self._cached_flop_cost[qp]