diff --git a/python/dune/codegen/sumfact/vectorization.py b/python/dune/codegen/sumfact/vectorization.py index 674078fd54249736249047001556c4ce8bc1b289..0b6b3a232bbc5b199b853c022743c8a18884f3fe 100644 --- a/python/dune/codegen/sumfact/vectorization.py +++ b/python/dune/codegen/sumfact/vectorization.py @@ -13,6 +13,7 @@ from dune.codegen.generation import (backend, get_counted_variable, get_global_context_value, kernel_cached, + retrieve_cache_items, ) from dune.codegen.pdelab.restriction import (Restriction, restricted_name, @@ -151,7 +152,7 @@ class PrimitiveApproximateOpcounter(FlopCounter): raise NotImplementedError("The class {} should implement a symbolic flopcounter.".format(type(expr))) -@kernel_cached +@generator_factory(item_tags=("opcounts",), context_tags="kernel") def store_operation_count(expr, count): return count @@ -183,7 +184,7 @@ def quadrature_penalized_strategy_cost(strat_tuple): # Determine the number of floating point operations per quadrature point. # This flop counting is a very crude approximation, but that is totally sufficient here. - ops_per_qp = sum(i.value for i in store_operation_count._memoize_cache.values()) + ops_per_qp = sum(i for i in retrieve_cache_items("opcounts")) # Do the actual scaling. return float((sf_flops + ops_per_qp * num_qp_new) / (sf_flops + ops_per_qp * num_qp_old)) * cost