Skip to content
Snippets Groups Projects
Commit 901cdcf9 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Cache the flopcost of a sumfact kernel object

parent d4f90764
No related branches found
No related tags found
No related merge requests found
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import itertools import itertools
from dune.codegen.options import get_option from dune.codegen.options import get_option
from dune.codegen.sumfact.tabulation import quadrature_points_per_direction
from dune.codegen.ufl.modified_terminals import Restriction from dune.codegen.ufl.modified_terminals import Restriction
......
...@@ -14,7 +14,7 @@ from dune.codegen.sumfact.permutation import (flop_cost, ...@@ -14,7 +14,7 @@ from dune.codegen.sumfact.permutation import (flop_cost,
sumfact_cost_permutation_strategy, sumfact_cost_permutation_strategy,
sumfact_quadrature_permutation_strategy, sumfact_quadrature_permutation_strategy,
) )
from dune.codegen.sumfact.tabulation import BasisTabulationMatrixBase, BasisTabulationMatrixArray from dune.codegen.sumfact.tabulation import BasisTabulationMatrixBase, BasisTabulationMatrixArray, quadrature_points_per_direction
from dune.codegen.loopy.target import dtype_floatingpoint, type_floatingpoint from dune.codegen.loopy.target import dtype_floatingpoint, type_floatingpoint
from dune.codegen.loopy.vcl import ExplicitVCLCast, VCLLowerUpperLoad from dune.codegen.loopy.vcl import ExplicitVCLCast, VCLLowerUpperLoad
from dune.codegen.tools import get_leaf, maybe_wrap_subscript, remove_duplicates from dune.codegen.tools import get_leaf, maybe_wrap_subscript, remove_duplicates
...@@ -564,6 +564,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): ...@@ -564,6 +564,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
# Precompute and cache a number of keys # Precompute and cache a number of keys
self._cached_cache_key = None self._cached_cache_key = None
self._cached_flop_cost = {}
# #
# The methods/fields needed to get a well-formed pymbolic node # The methods/fields needed to get a well-formed pymbolic node
...@@ -824,7 +825,10 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): ...@@ -824,7 +825,10 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
def operations(self): def operations(self):
""" The total number of floating point operations for the kernel """ The total number of floating point operations for the kernel
to be carried out """ to be carried out """
return flop_cost(self.matrix_sequence_cost_permuted) qp = quadrature_points_per_direction()
if qp not in self._cached_flop_cost:
self._cached_flop_cost[qp] = flop_cost(self.matrix_sequence_cost_permuted)
return self._cached_flop_cost[qp]
# Extract the argument list and store it on the class. This needs to be done # Extract the argument list and store it on the class. This needs to be done
...@@ -873,6 +877,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) ...@@ -873,6 +877,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
# Precompute and cache a number of keys # Precompute and cache a number of keys
self._cached_cache_key = None self._cached_cache_key = None
self._cached_flop_cost = {}
def __getinitargs__(self): def __getinitargs__(self):
return (self.kernels, self.horizontal_width, self.vertical_width, self.buffer, self.insn_dep) return (self.kernels, self.horizontal_width, self.vertical_width, self.buffer, self.insn_dep)
...@@ -1135,4 +1140,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) ...@@ -1135,4 +1140,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
def operations(self): def operations(self):
""" The total number of floating point operations for the kernel """ The total number of floating point operations for the kernel
to be carried out """ to be carried out """
return flop_cost(self.matrix_sequence_cost_permuted) qp = quadrature_points_per_direction()
if qp not in self._cached_flop_cost:
self._cached_flop_cost[qp] = flop_cost(self.matrix_sequence_cost_permuted)
return self._cached_flop_cost[qp]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment