Skip to content
Snippets Groups Projects
Commit c4579aaf authored by Dominic Kempf's avatar Dominic Kempf
Browse files

More fixes and honor position_priority in cost model

parent 0c68888c
No related branches found
No related tags found
No related merge requests found
......@@ -21,6 +21,7 @@ from dune.perftool.error import PerftoolError
from dune.perftool.options import get_option
from dune.perftool.tools import add_to_frozendict,round_to_multiple
from pytools import product
from frozendict import frozendict
import itertools as it
import loopy as lp
......@@ -69,9 +70,11 @@ def explicit_costfunction(sf):
vertical = int(vertical)
if sf.horizontal_width == horizontal and sf.vertical_width == vertical:
return 1
# Penalize position mapping
penalty = sum(abs(sf.kernels[i].position_priority - i) if sf.kernels[i].position_priority is not None else 0 for i in range(sf.length))
return 1 + penalty
else:
return 2
return 1000000000000
def strategy_cost(strategy):
......@@ -79,7 +82,16 @@ def strategy_cost(strategy):
set_quadrature_points(qp)
func = get_backend(interface="vectorization_strategy",
selector=lambda: get_option("vectorization_strategy"))
return sum(float(func(sf)) for sf in strategy.values())
keys = set(sf.cache_key for sf in strategy.values())
# Sum over all the sum factorization kernels in the realization
score = 0.0
for sf in strategy.values():
if sf.cache_key in keys:
score = score + float(func(sf))
keys.discard(sf.cache_key)
return score
def stringify_vectorization_strategy(strategy):
......@@ -164,7 +176,7 @@ def vectorization_opportunity_generator(sumfacts, width):
#
quad_points = [quadrature_points_per_direction()]
if True or get_option("vectorization_allow_quadrature_changes"):
if get_option("vectorization_allow_quadrature_changes"):
sf = next(iter(sumfacts))
depth = 1
while depth <= width:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment