diff --git a/python/dune/codegen/sumfact/symbolic.py b/python/dune/codegen/sumfact/symbolic.py index dfd9383f93c79d3c50bfcca24cdd717fcc16aa58..8d81ea0462541cff4e952f37bd02d34c2e36454a 100644 --- a/python/dune/codegen/sumfact/symbolic.py +++ b/python/dune/codegen/sumfact/symbolic.py @@ -562,6 +562,9 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): ImmutableRecord.__init__(self, **defaultdict) prim.Variable.__init__(self, "SUMFACT") + # Precompute and cache a number of keys + self._cached_cache_key = None + # # The methods/fields needed to get a well-formed pymbolic node # @@ -617,12 +620,15 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): Any two sum factorization kernels having the same cache_key are realized simultaneously! """ - if self.buffer is None: - # During dry run, we return something unique to this kernel - return repr(self) - else: - # Later we identify parallely implemented kernels by the assigned buffer - return self.buffer + if self._cached_cache_key is None: + if self.buffer is None: + # During dry run, we return something unique to this kernel + self._cached_cache_key = repr(self) + else: + # Later we identify parallely implemented kernels by the assigned buffer + self._cached_cache_key = self.buffer + + return self._cached_cache_key @property def inout_key(self): @@ -865,6 +871,9 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) prim.Variable.__init__(self, "VecSUMFAC") + # Precompute and cache a number of keys + self._cached_cache_key = None + def __getinitargs__(self): return (self.kernels, self.horizontal_width, self.vertical_width, self.buffer, self.insn_dep) @@ -897,7 +906,10 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) Any two sum factorization kernels having the same cache_key are realized simulatenously! """ - return (self.matrix_sequence_quadrature_permuted, self.restriction, self.stage, self.buffer) + if self._cached_cache_key is None: + self._cached_cache_key = (self.matrix_sequence_quadrature_permuted, self.restriction, self.stage, self.buffer) + + return self._cached_cache_key # # Deduce all data fields of normal sum factorization kernels from the underlying kernels