Skip to content
Snippets Groups Projects
Commit 1bc2247f authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Implement symbolic operation counting for penalized cost model

parent 2a56e7ba
No related branches found
No related tags found
No related merge requests found
...@@ -389,6 +389,10 @@ def generate_accumulation_instruction(expr, visitor): ...@@ -389,6 +389,10 @@ def generate_accumulation_instruction(expr, visitor):
# Cache all stage 1 sum factorization kernels used in this expression # Cache all stage 1 sum factorization kernels used in this expression
SumfactCollectMapper()(expr) SumfactCollectMapper()(expr)
# Count flops on the expression for the vectorization decision making algorithm
from dune.perftool.sumfact.vectorization import count_quadrature_point_operations
count_quadrature_point_operations(expr)
# Number of basis functions per direction # Number of basis functions per direction
leaf_element = test_info.element leaf_element = test_info.element
from ufl import MixedElement from ufl import MixedElement
......
...@@ -12,6 +12,7 @@ from dune.perftool.generation import (backend, ...@@ -12,6 +12,7 @@ from dune.perftool.generation import (backend,
get_backend, get_backend,
get_counted_variable, get_counted_variable,
get_global_context_value, get_global_context_value,
kernel_cached,
) )
from dune.perftool.pdelab.restriction import (Restriction, from dune.perftool.pdelab.restriction import (Restriction,
restricted_name, restricted_name,
...@@ -24,6 +25,7 @@ from dune.perftool.error import PerftoolVectorizationError ...@@ -24,6 +25,7 @@ from dune.perftool.error import PerftoolVectorizationError
from dune.perftool.options import get_form_option, get_option, set_form_option from dune.perftool.options import get_form_option, get_option, set_form_option
from dune.perftool.tools import add_to_frozendict, round_to_multiple, list_diff from dune.perftool.tools import add_to_frozendict, round_to_multiple, list_diff
from pymbolic.mapper.flop_counter import FlopCounter
from pytools import product from pytools import product
from frozendict import frozendict from frozendict import frozendict
import itertools as it import itertools as it
...@@ -136,19 +138,49 @@ def strategy_cost(strat_tuple): ...@@ -136,19 +138,49 @@ def strategy_cost(strat_tuple):
return accumulate_for_strategy(strategy, lambda sf: float(func(sf))) return accumulate_for_strategy(strategy, lambda sf: float(func(sf)))
class PrimitiveApproximateOpcounter(FlopCounter):
def map_sumfact_kernel(self, expr):
return 0
def map_tagged_variable(self, expr):
return self.map_variable(expr)
@kernel_cached
def store_operation_count(expr, count):
return count
def count_quadrature_point_operations(expr):
counter = PrimitiveApproximateOpcounter()
store_operation_count(expr, counter(expr))
def quadrature_penalized_strategy_cost(strat_tuple): def quadrature_penalized_strategy_cost(strat_tuple):
""" Implements a penalization of the cost function that accounts for
the increase in flops that occur in the quadrature loop. This needs to
somehow get a guess of how much work is done in the quadrature loop relative
to the sum factorization kernels.
"""
qp, strategy = strat_tuple
# Evaluate the original cost function. This result will be scaled by this function.
cost = strategy_cost(strat_tuple) cost = strategy_cost(strat_tuple)
qp, strategy = strat_tuple # Get the total number of Flops done in sum factorization kernels
num_qp_new = product(qp)
sf_flops = accumulate_for_strategy(strategy, lambda sf: sf.operations) sf_flops = accumulate_for_strategy(strategy, lambda sf: sf.operations)
# Get the minimal possible number of quadrature points and the actual quadrature points
num_qp_new = product(qp)
set_quadrature_points(None) set_quadrature_points(None)
num_qp_old = product(quadrature_points_per_direction()) num_qp_old = product(quadrature_points_per_direction())
set_quadrature_points(qp) set_quadrature_points(qp)
# TODO determine this # Determine the number of floating point operations per quadrature point.
ops_per_qp = 100 # This flop counting is a very crude approximation, but that is totally sufficient here.
ops_per_qp = sum(i.value for i in store_operation_count._memoize_cache.values())
# Do the actual scaling.
return float((sf_flops + ops_per_qp * num_qp_new) / (sf_flops + ops_per_qp * num_qp_old)) * cost return float((sf_flops + ops_per_qp * num_qp_new) / (sf_flops + ops_per_qp * num_qp_old)) * cost
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment