Commit 29b60014 authored by René Heß

[skip ci] Autotune loop order in tensor contraction

parent 3d7cb6ac
@@ -9,6 +9,7 @@ from dune.codegen.loopy.transformations.remove_reductions import remove_all_redu
 from dune.codegen.options import get_form_option, get_option
 from dune.codegen.pdelab.geometry import world_dimension
 from dune.codegen.error import CodegenAutotuneError
+from dune.codegen.sumfact.autotune import autotune_realization


 def move_zero_assignment_up(kernel, move_up_inames):
@@ -169,16 +170,46 @@ def reorder_loops_in_tensor_contraction(kernel, iname_order):
     return kernel


+def tensor_contraction_loop_order_generator(kernel):
+    dim = world_dimension()
+    assert dim == 3
+    yield kernel
+
+    indices = ['l', 'k', 'i', 'j']
+    import itertools
+    for loop_order in itertools.permutations(indices):
+        loop_order = ''.join(loop_order)
+        new_kernel = reorder_loops_in_tensor_contraction(kernel, loop_order)
+        yield new_kernel
+
+
+def simple_autotuner(kernel_generator, signature):
+    # palpo TODO
+    from dune.codegen.options import set_option
+    set_option("autotune_google_benchmark", True)
+
+    kernel = next(kernel_generator)
+    best_cost = autotune_realization(kernel=kernel, signature=signature)
+    best_kernel = kernel
+    for kernel in kernel_generator:
+        cost = autotune_realization(kernel=kernel, signature=signature)
+        if cost < best_cost:
+            best_cost = cost
+            best_kernel = kernel
+    return best_kernel
+
+
+def autotune_tensor_contraction_loop_order(kernel, signature):
+    from dune.codegen.loopy.transformations.matchfma import match_fused_multiply_add
+    kernel = match_fused_multiply_add(kernel)
+    generator = tensor_contraction_loop_order_generator(kernel)
+    return simple_autotuner(generator, signature)
+
+
 def sumfact_performance_transformations(kernel, signature):
     if kernel.name.startswith('sfimpl'):
-        # from dune.codegen.loopy.transformations.matchfma import match_fused_multiply_add
-        # kernel = match_fused_multiply_add(kernel)
-        # kernel = reorder_loops_in_tensor_contraction(kernel, 'ijlk')
-        # from dune.codegen.sumfact.autotune import autotune_realization
-        # from dune.codegen.options import set_option
-        # set_option("autotune_google_benchmark", True)
-        # test = autotune_realization(kernel=kernel, signature=signature)
+        # kernel = autotune_tensor_contraction_loop_order(kernel, signature)
         pass

     return kernel
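
The added functions implement a plain generate-and-measure autotuning loop: enumerate all permutations of the contraction inames, benchmark each transformed kernel via autotune_realization, and keep the cheapest one. The following is a minimal, self-contained sketch of the same pattern outside dune-codegen; loop_order_candidates, pick_best, and dummy_cost are illustrative stand-ins (not part of the commit) for the loop-order generator, simple_autotuner, and the benchmark-based cost measurement.

    import itertools

    def loop_order_candidates(indices=('l', 'k', 'i', 'j')):
        # Yield every permutation of the contraction inames as a string,
        # e.g. 'lkij', 'lkji', ... (4! = 24 candidates for the 3D case).
        for perm in itertools.permutations(indices):
            yield ''.join(perm)

    def pick_best(candidates, measure):
        # Generic generate-and-measure loop: evaluate the cost of every
        # candidate and return the cheapest one, mirroring simple_autotuner.
        best = None
        best_cost = float('inf')
        for candidate in candidates:
            cost = measure(candidate)
            if cost < best_cost:
                best_cost = cost
                best = candidate
        return best, best_cost

    if __name__ == '__main__':
        # Stand-in cost model: pretend the loop order 'ijlk' is fastest.
        dummy_cost = lambda order: 1.0 if order == 'ijlk' else 2.0
        print(pick_best(loop_order_candidates(), dummy_cost))

In the commit itself the candidate is a transformed loopy kernel rather than a string, and the cost comes from running the generated benchmark (set_option("autotune_google_benchmark", True)), but the selection logic is the same exhaustive search over loop orders.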