diff --git a/python/dune/perftool/sumfact/accumulation.py b/python/dune/perftool/sumfact/accumulation.py index d0c6fba68e42c7354d2818395454e2e4de254ff9..5f396035bc1adb8f35776644389d081d2b6c8a21 100644 --- a/python/dune/perftool/sumfact/accumulation.py +++ b/python/dune/perftool/sumfact/accumulation.py @@ -389,6 +389,10 @@ def generate_accumulation_instruction(expr, visitor): # Cache all stage 1 sum factorization kernels used in this expression SumfactCollectMapper()(expr) + # Count flops on the expression for the vectorization decision making algorithm + from dune.perftool.sumfact.vectorization import count_quadrature_point_operations + count_quadrature_point_operations(expr) + # Number of basis functions per direction leaf_element = test_info.element from ufl import MixedElement diff --git a/python/dune/perftool/sumfact/vectorization.py b/python/dune/perftool/sumfact/vectorization.py index cb6e879312bfe8816e447b826a87953271475704..c7eac157cc02ee608244dc499c4d6f7f93c18b67 100644 --- a/python/dune/perftool/sumfact/vectorization.py +++ b/python/dune/perftool/sumfact/vectorization.py @@ -12,6 +12,7 @@ from dune.perftool.generation import (backend, get_backend, get_counted_variable, get_global_context_value, + kernel_cached, ) from dune.perftool.pdelab.restriction import (Restriction, restricted_name, @@ -24,6 +25,7 @@ from dune.perftool.error import PerftoolVectorizationError from dune.perftool.options import get_form_option, get_option, set_form_option from dune.perftool.tools import add_to_frozendict, round_to_multiple, list_diff +from pymbolic.mapper.flop_counter import FlopCounter from pytools import product from frozendict import frozendict import itertools as it @@ -136,19 +138,49 @@ def strategy_cost(strat_tuple): return accumulate_for_strategy(strategy, lambda sf: float(func(sf))) +class PrimitiveApproximateOpcounter(FlopCounter): + def map_sumfact_kernel(self, expr): + return 0 + + def map_tagged_variable(self, expr): + return self.map_variable(expr) + + +@kernel_cached +def store_operation_count(expr, count): + return count + + +def count_quadrature_point_operations(expr): + counter = PrimitiveApproximateOpcounter() + store_operation_count(expr, counter(expr)) + + def quadrature_penalized_strategy_cost(strat_tuple): + """ Implements a penalization of the cost function that accounts for + the increase in flops that occur in the quadrature loop. This needs to + somehow get a guess of how much work is done in the quadrature loop relative + to the sum factorization kernels. + """ + qp, strategy = strat_tuple + + # Evaluate the original cost function. This result will be scaled by this function. cost = strategy_cost(strat_tuple) - qp, strategy = strat_tuple - num_qp_new = product(qp) + # Get the total number of Flops done in sum factorization kernels sf_flops = accumulate_for_strategy(strategy, lambda sf: sf.operations) + # Get the minimal possible number of quadrature points and the actual quadrature points + num_qp_new = product(qp) set_quadrature_points(None) num_qp_old = product(quadrature_points_per_direction()) set_quadrature_points(qp) - # TODO determine this - ops_per_qp = 100 + # Determine the number of floating point operations per quadrature point. + # This flop counting is a very crude approximation, but that is totally sufficient here. + ops_per_qp = sum(i.value for i in store_operation_count._memoize_cache.values()) + + # Do the actual scaling. return float((sf_flops + ops_per_qp * num_qp_new) / (sf_flops + ops_per_qp * num_qp_old)) * cost