Skip to content
Snippets Groups Projects
Commit abb2766a authored by René Heß's avatar René Heß
Browse files

[skip ci] Finish direct accumulation loop reordering

parent 99691a29
No related branches found
No related tags found
No related merge requests found
...@@ -10,9 +10,9 @@ def remove_reduction(knl, match): ...@@ -10,9 +10,9 @@ def remove_reduction(knl, match):
for instr in lp.find_instructions(knl, match): for instr in lp.find_instructions(knl, match):
if isinstance(instr.expression, lp.symbolic.Reduction): if isinstance(instr.expression, lp.symbolic.Reduction):
instructions = [] instructions = []
depends_on = instr.depends_on
# Depending on this instruction # Dependencies
depends_on = instr.depends_on
depending = [] depending = []
for i in knl.instructions: for i in knl.instructions:
if instr.id in i.depends_on: if instr.id in i.depends_on:
...@@ -44,6 +44,7 @@ def remove_reduction(knl, match): ...@@ -44,6 +44,7 @@ def remove_reduction(knl, match):
knl = knl.copy(instructions=knl.instructions + instructions) knl = knl.copy(instructions=knl.instructions + instructions)
# Restore dependencies
for dep in depending: for dep in depending:
match = lp.match.Id(dep) match = lp.match.Id(dep)
knl = lp.add_dependency(knl, match, id_accum) knl = lp.add_dependency(knl, match, id_accum)
......
...@@ -8,7 +8,7 @@ import islpy as isl ...@@ -8,7 +8,7 @@ import islpy as isl
from dune.codegen.generation import (get_counted_variable, from dune.codegen.generation import (get_counted_variable,
get_global_context_value, get_global_context_value,
) )
from dune.codegen.loopy.transformations.remove_reductions import remove_all_reductions from dune.codegen.loopy.transformations.remove_reductions import remove_all_reductions, remove_reduction
from dune.codegen.options import get_form_option, get_option from dune.codegen.options import get_form_option, get_option
from dune.codegen.pdelab.geometry import world_dimension from dune.codegen.pdelab.geometry import world_dimension
from dune.codegen.error import CodegenAutotuneError from dune.codegen.error import CodegenAutotuneError
...@@ -63,6 +63,15 @@ def _get_inames_of_reduction(instr, iname_permutation): ...@@ -63,6 +63,15 @@ def _get_inames_of_reduction(instr, iname_permutation):
return outer_inames, reduction_iname, inner_inames, vec_inames return outer_inames, reduction_iname, inner_inames, vec_inames
def _get_iname_bound(kernel, iname):
    """Return the trip count of *iname* (inclusive upper bound + 1).

    Looks up the leaf ISL domain that constrains *iname*, takes the
    maximum of its first dimension as a piecewise affine expression and
    converts it back to a pymbolic expression.

    TODO: Not sure if that works in all cases
    """
    leaf_index, = kernel.get_leaf_domain_indices((iname,))
    max_pwaff = kernel.domains[leaf_index].dim_max(0)
    return lp.symbolic.pw_aff_to_expr(max_pwaff) + 1
def _reorder_loops_in_tensor_contraction_direct(kernel, iname_permutation): def _reorder_loops_in_tensor_contraction_direct(kernel, iname_permutation):
"""Reorder the loop nests of a tensor contraction accumulating directly in the data structure""" """Reorder the loop nests of a tensor contraction accumulating directly in the data structure"""
dim = world_dimension() dim = world_dimension()
...@@ -76,8 +85,6 @@ def _reorder_loops_in_tensor_contraction_direct(kernel, iname_permutation): ...@@ -76,8 +85,6 @@ def _reorder_loops_in_tensor_contraction_direct(kernel, iname_permutation):
from dune.codegen.sumfact.permutation import permute_backward from dune.codegen.sumfact.permutation import permute_backward
new_iname_order = permute_backward(default_iname_order, iname_permutation) new_iname_order = permute_backward(default_iname_order, iname_permutation)
kernel = remove_all_reductions(kernel)
for instr in kernel.instructions: for instr in kernel.instructions:
# Inames used in this reduction # Inames used in this reduction
outer_inames, reduction_iname, inner_inames, vec_inames = _get_inames_of_reduction(instr, outer_inames, reduction_iname, inner_inames, vec_inames = _get_inames_of_reduction(instr,
...@@ -88,34 +95,115 @@ def _reorder_loops_in_tensor_contraction_direct(kernel, iname_permutation): ...@@ -88,34 +95,115 @@ def _reorder_loops_in_tensor_contraction_direct(kernel, iname_permutation):
current_inames = outer_inames + inner_inames + vec_inames current_inames = outer_inames + inner_inames + vec_inames
current_iname_order = _current_iname_order(current_inames, current_iname_order = _current_iname_order(current_inames,
new_iname_order) new_iname_order)
if iname_permutation[-1] == dim: if iname_permutation[-1] == dim:
kernel = lp.prioritize_loops(kernel, tuple(current_iname_order)) kernel = lp.prioritize_loops(kernel, tuple(current_iname_order))
continue continue
# palpo TODO if isinstance(instr.expression, lp.symbolic.Reduction):
if 'haddsubst' in str(instr): if isinstance(instr.assignee, prim.Subscript):
continue assert set(inner_inames).issubset(set(i.name for i in instr.assignee.index_tuple))
match = lp.match.Id(instr.id)
kernel = remove_reduction(kernel, match)
lp.prioritize_loops(kernel, current_iname_order)
duplicate_inames = tuple(inner_inames)
match = lp.match.Id(instr.id + '_set_zero')
kernel = lp.duplicate_inames(kernel, duplicate_inames, match)
match_inames = tuple(lp.find_instructions(kernel, match)[0].within_inames)
set_zero_iname_order = _current_iname_order(match_inames, new_iname_order)
lp.prioritize_loops(kernel, tuple(set_zero_iname_order))
else:
# Dependencies
match = lp.match.Id(instr.id)
depends_on = lp.find_instructions(kernel, match)[0].depends_on
depending = []
for i in kernel.instructions:
if instr.id in i.depends_on:
depending.append(i.id)
# Remove reduction
kernel = lp.remove_instructions(kernel, set([instr.id]))
# Create dim_tags
dim_tags = ','.join(['f'] * len(inner_inames))
vectorized = len(vec_inames) > 0
if vectorized:
assert len(vec_inames) == 1
dim_tags = dim_tags + ',vec'
# Create shape
shape = tuple(_get_iname_bound(kernel, i) for i in inner_inames + vec_inames)
# Update temporary_variables of this kernel
from dune.codegen.loopy.temporary import DuneTemporaryVariable
accum_variable = get_counted_variable('accum_variable')
from dune.codegen.loopy.target import dtype_floatingpoint
dtype = lp.types.NumpyType(dtype_floatingpoint())
var = {accum_variable: DuneTemporaryVariable(accum_variable,
dtype=dtype,
shape=shape,
dim_tags=dim_tags,
managed=True)}
tv = kernel.temporary_variables.copy()
tv.update(var)
del tv[instr.assignee.name]
kernel = kernel.copy(temporary_variables=tv)
# Set accumulation variable to zero
accum_init_inames = tuple(prim.Variable(i) for i in inner_inames)
if vectorized:
accum_init_inames = accum_init_inames + (prim.Variable(vec_inames[0]),)
assignee = prim.Subscript(prim.Variable(accum_variable,), accum_init_inames)
accum_init_id = instr.id + '_accum_init'
accum_init_instr = lp.Assignment(assignee,
0,
within_inames=instr.within_inames,
id=accum_init_id,
depends_on=depends_on,
tags=('accum_init',),
)
kernel = kernel.copy(instructions=kernel.instructions + [accum_init_instr])
# Accumulate in temporary variable
assignee = prim.Subscript(prim.Variable(accum_variable,), accum_init_inames)
expression = prim.Sum((assignee, instr.expression.expr))
within_inames = frozenset(tuple(instr.within_inames) + instr.expression.inames)
accum_id = instr.id + '_accum'
accum_instr = lp.Assignment(assignee,
expression,
within_inames=within_inames,
id=accum_id,
depends_on=frozenset([accum_init_id]),
tags=('accum',),
)
kernel = kernel.copy(instructions=kernel.instructions + [accum_instr])
# Duplicate inames and reorder
duplicate_inames = tuple(inner_inames)
for idx in [accum_init_id, accum_id]:
match = lp.match.Id(idx)
kernel = lp.duplicate_inames(kernel, duplicate_inames, match)
match_inames = lp.find_instructions(kernel, match)[0].within_inames
current_iname_order = _current_iname_order(match_inames, new_iname_order)
kernel = lp.prioritize_loops(kernel, tuple(current_iname_order))
# Restore dependencies
for dep in depending:
match = lp.match.Id(dep)
kernel = lp.add_dependency(kernel, match, accum_id)
match = lp.match.Tagged('sumfact_stage3')
assign_instr, = lp.find_instructions(kernel, match)
from dune.codegen.loopy.symbolic import substitute
subst = {instr.assignee.name: assignee}
new_assign_instr = assign_instr.copy(expression=substitute(assign_instr.expression, subst),
id=assign_instr.id + '_mod')
kernel = kernel.copy(instructions=kernel.instructions + [new_assign_instr])
kernel = lp.remove_instructions(kernel, set([assign_instr.id]))
# if 'assignment' in instr.tags or isinstance(instr.assignee, prim.Variable):
if 'assignment' in instr.tags:
# Set loop priority
lp.prioritize_loops(kernel, current_iname_order)
elif 'set_zero' in instr.tags:
# Duplicate inames and prioritize loops
duplicate_inames = tuple(i for i in inner_inames)
match = lp.match.Id(instr.id)
kernel = lp.duplicate_inames(kernel, duplicate_inames, match)
# palpo TODO prioritize!
else: else:
# palpo TODO 2D? current_iname_order = _current_iname_order(instr.within_inames, new_iname_order)
# assert reduction_iname is None lp.prioritize_loops(kernel, current_iname_order)
# Duplicate inames and prioritize loops
duplicate_inames = tuple(i for i in inner_inames)
match = lp.match.Id(instr.id)
kernel = lp.duplicate_inames(kernel, duplicate_inames, match)
# palpo TODO prioritize!
return kernel return kernel
...@@ -173,13 +261,6 @@ def _reorder_loops_in_tensor_contraction_accum(kernel, iname_permutation): ...@@ -173,13 +261,6 @@ def _reorder_loops_in_tensor_contraction_accum(kernel, iname_permutation):
dim_tags = dim_tags + ',vec' dim_tags = dim_tags + ',vec'
# Create shape # Create shape
def _get_iname_bound(kernel, iname):
# TODO: Not sure if that works in all cases
ldi, = kernel.get_leaf_domain_indices((iname,))
domain = kernel.domains[ldi]
pwaff = domain.dim_max(0)
bound = lp.symbolic.pw_aff_to_expr(pwaff)
return bound + 1
shape = tuple(_get_iname_bound(kernel, i) for i in inner_inames + vec_inames) shape = tuple(_get_iname_bound(kernel, i) for i in inner_inames + vec_inames)
# Update temporary_variables of this kernel # Update temporary_variables of this kernel
...@@ -345,6 +426,8 @@ def reorder_loops_in_tensor_contraction(kernel, iname_permutation, accum_variabl ...@@ -345,6 +426,8 @@ def reorder_loops_in_tensor_contraction(kernel, iname_permutation, accum_variabl
def tensor_contraction_loop_order_generator(kernel): def tensor_contraction_loop_order_generator(kernel):
yield kernel, ['None']
dim = world_dimension() dim = world_dimension()
identity = range(dim + 1) identity = range(dim + 1)
import itertools import itertools
...@@ -356,9 +439,8 @@ def tensor_contraction_loop_order_generator(kernel): ...@@ -356,9 +439,8 @@ def tensor_contraction_loop_order_generator(kernel):
new_kernel = reorder_loops_in_tensor_contraction(kernel, permutation, accum_variable=True) new_kernel = reorder_loops_in_tensor_contraction(kernel, permutation, accum_variable=True)
yield new_kernel, ['reorder_loops_in_tensor_contraction_{}_True'.format(permutation)] yield new_kernel, ['reorder_loops_in_tensor_contraction_{}_True'.format(permutation)]
# palpo TODO new_kernel = reorder_loops_in_tensor_contraction(kernel, permutation, accum_variable=False)
# new_kernel = reorder_loops_in_tensor_contraction(kernel, permutation, accum_variable=False) yield new_kernel, ['reorder_loops_in_tensor_contraction_{}_False'.format(permutation),]
# yield new_kernel, ['reorder_loops_in_tensor_contraction_{}_False'.format(permutation),]
def simple_autotuner(kernel_generator, signature): def simple_autotuner(kernel_generator, signature):
...@@ -391,12 +473,12 @@ def sumfact_performance_transformations(kernel, signature): ...@@ -391,12 +473,12 @@ def sumfact_performance_transformations(kernel, signature):
# dim = world_dimension() # dim = world_dimension()
# if dim == 2: # if dim == 2:
# # assert False # # assert False
# kernel = reorder_loops_in_tensor_contraction(kernel, (2, 0, 1), True) # # kernel = reorder_loops_in_tensor_contraction(kernel, (2, 0, 1), True)
# # kernel = reorder_loops_in_tensor_contraction(kernel, (2, 0, 1), False) # kernel = reorder_loops_in_tensor_contraction(kernel, (2, 0, 1), False)
# else: # else:
# kernel = reorder_loops_in_tensor_contraction(kernel, (3, 2, 0, 1), True) # # kernel = reorder_loops_in_tensor_contraction(kernel, (3, 2, 0, 1), True)
# # kernel = reorder_loops_in_tensor_contraction(kernel, (1, 2, 0, 3), True) # # kernel = reorder_loops_in_tensor_contraction(kernel, (1, 2, 0, 3), True)
# # kernel = reorder_loops_in_tensor_contraction(kernel, (3, 2, 0, 1), False) # kernel = reorder_loops_in_tensor_contraction(kernel, (3, 2, 0, 1), False)
kernel = autotune_tensor_contraction_loop_order(kernel, signature) kernel = autotune_tensor_contraction_loop_order(kernel, signature)
pass pass
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment