From abb2766ae6e8eff812423703c322b76f9a9b75ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ren=C3=A9=20He=C3=9F?= <rene.hess@iwr.uni-heidelberg.de> Date: Thu, 18 Apr 2019 14:22:26 +0200 Subject: [PATCH] [skip ci] Finish direct accumulation loop reordering --- .../transformations/remove_reductions.py | 5 +- .../dune/codegen/sumfact/transformations.py | 160 +++++++++++++----- 2 files changed, 124 insertions(+), 41 deletions(-) diff --git a/python/dune/codegen/loopy/transformations/remove_reductions.py b/python/dune/codegen/loopy/transformations/remove_reductions.py index dd083951..1ccca09c 100644 --- a/python/dune/codegen/loopy/transformations/remove_reductions.py +++ b/python/dune/codegen/loopy/transformations/remove_reductions.py @@ -10,9 +10,9 @@ def remove_reduction(knl, match): for instr in lp.find_instructions(knl, match): if isinstance(instr.expression, lp.symbolic.Reduction): instructions = [] - depends_on = instr.depends_on - # Depending on this instruction + # Dependencies + depends_on = instr.depends_on depending = [] for i in knl.instructions: if instr.id in i.depends_on: @@ -44,6 +44,7 @@ def remove_reduction(knl, match): knl = knl.copy(instructions=knl.instructions + instructions) + # Restore dependencies for dep in depending: match = lp.match.Id(dep) knl = lp.add_dependency(knl, match, id_accum) diff --git a/python/dune/codegen/sumfact/transformations.py b/python/dune/codegen/sumfact/transformations.py index d3765e55..89ed9321 100644 --- a/python/dune/codegen/sumfact/transformations.py +++ b/python/dune/codegen/sumfact/transformations.py @@ -8,7 +8,7 @@ import islpy as isl from dune.codegen.generation import (get_counted_variable, get_global_context_value, ) -from dune.codegen.loopy.transformations.remove_reductions import remove_all_reductions +from dune.codegen.loopy.transformations.remove_reductions import remove_all_reductions, remove_reduction from dune.codegen.options import get_form_option, get_option from dune.codegen.pdelab.geometry import world_dimension from dune.codegen.error import CodegenAutotuneError @@ -63,6 +63,15 @@ def _get_inames_of_reduction(instr, iname_permutation): return outer_inames, reduction_iname, inner_inames, vec_inames +def _get_iname_bound(kernel, iname): + # TODO: Not sure if that works in all cases + ldi, = kernel.get_leaf_domain_indices((iname,)) + domain = kernel.domains[ldi] + pwaff = domain.dim_max(0) + bound = lp.symbolic.pw_aff_to_expr(pwaff) + return bound + 1 + + def _reorder_loops_in_tensor_contraction_direct(kernel, iname_permutation): """Reorder the loop nests of a tensor contraction accumulating directly in the data structure""" dim = world_dimension() @@ -76,8 +85,6 @@ def _reorder_loops_in_tensor_contraction_direct(kernel, iname_permutation): from dune.codegen.sumfact.permutation import permute_backward new_iname_order = permute_backward(default_iname_order, iname_permutation) - kernel = remove_all_reductions(kernel) - for instr in kernel.instructions: # Inames used in this reduction outer_inames, reduction_iname, inner_inames, vec_inames = _get_inames_of_reduction(instr, @@ -88,34 +95,115 @@ def _reorder_loops_in_tensor_contraction_direct(kernel, iname_permutation): current_inames = outer_inames + inner_inames + vec_inames current_iname_order = _current_iname_order(current_inames, new_iname_order) - if iname_permutation[-1] == dim: kernel = lp.prioritize_loops(kernel, tuple(current_iname_order)) continue - # palpo TODO - if 'haddsubst' in str(instr): - continue + if isinstance(instr.expression, lp.symbolic.Reduction): + if isinstance(instr.assignee, prim.Subscript): + assert set(inner_inames).issubset(set(i.name for i in instr.assignee.index_tuple)) + match = lp.match.Id(instr.id) + kernel = remove_reduction(kernel, match) + + lp.prioritize_loops(kernel, current_iname_order) + duplicate_inames = tuple(inner_inames) + match = lp.match.Id(instr.id + '_set_zero') + kernel = lp.duplicate_inames(kernel, duplicate_inames, match) + match_inames = tuple(lp.find_instructions(kernel, match)[0].within_inames) + set_zero_iname_order = _current_iname_order(match_inames, new_iname_order) + lp.prioritize_loops(kernel, tuple(set_zero_iname_order)) + else: + # Dependencies + match = lp.match.Id(instr.id) + depends_on = lp.find_instructions(kernel, match)[0].depends_on + depending = [] + for i in kernel.instructions: + if instr.id in i.depends_on: + depending.append(i.id) + + # Remove reduction + kernel = lp.remove_instructions(kernel, set([instr.id])) + + # Create dim_tags + dim_tags = ','.join(['f'] * len(inner_inames)) + vectorized = len(vec_inames) > 0 + if vectorized: + assert len(vec_inames) == 1 + dim_tags = dim_tags + ',vec' + + # Create shape + shape = tuple(_get_iname_bound(kernel, i) for i in inner_inames + vec_inames) + + # Update temporary_variables of this kernel + from dune.codegen.loopy.temporary import DuneTemporaryVariable + accum_variable = get_counted_variable('accum_variable') + from dune.codegen.loopy.target import dtype_floatingpoint + dtype = lp.types.NumpyType(dtype_floatingpoint()) + var = {accum_variable: DuneTemporaryVariable(accum_variable, + dtype=dtype, + shape=shape, + dim_tags=dim_tags, + managed=True)} + tv = kernel.temporary_variables.copy() + tv.update(var) + del tv[instr.assignee.name] + kernel = kernel.copy(temporary_variables=tv) + + # Set accumulation variable to zero + accum_init_inames = tuple(prim.Variable(i) for i in inner_inames) + if vectorized: + accum_init_inames = accum_init_inames + (prim.Variable(vec_inames[0]),) + assignee = prim.Subscript(prim.Variable(accum_variable,), accum_init_inames) + accum_init_id = instr.id + '_accum_init' + accum_init_instr = lp.Assignment(assignee, + 0, + within_inames=instr.within_inames, + id=accum_init_id, + depends_on=depends_on, + tags=('accum_init',), + ) + kernel = kernel.copy(instructions=kernel.instructions + [accum_init_instr]) + + # Accumulate in temporary variable + assignee = prim.Subscript(prim.Variable(accum_variable,), accum_init_inames) + expression = prim.Sum((assignee, instr.expression.expr)) + within_inames = frozenset(tuple(instr.within_inames) + instr.expression.inames) + accum_id = instr.id + '_accum' + accum_instr = lp.Assignment(assignee, + expression, + within_inames=within_inames, + id=accum_id, + depends_on=frozenset([accum_init_id]), + tags=('accum',), + ) + kernel = kernel.copy(instructions=kernel.instructions + [accum_instr]) + + # Duplicate inames and reorder + duplicate_inames = tuple(inner_inames) + for idx in [accum_init_id, accum_id]: + match = lp.match.Id(idx) + kernel = lp.duplicate_inames(kernel, duplicate_inames, match) + match_inames = lp.find_instructions(kernel, match)[0].within_inames + current_iname_order = _current_iname_order(match_inames, new_iname_order) + kernel = lp.prioritize_loops(kernel, tuple(current_iname_order)) + + # Restore dependencies + for dep in depending: + match = lp.match.Id(dep) + kernel = lp.add_dependency(kernel, match, accum_id) + + match = lp.match.Tagged('sumfact_stage3') + assign_instr, = lp.find_instructions(kernel, match) + from dune.codegen.loopy.symbolic import substitute + subst = {instr.assignee.name: assignee} + new_assign_instr = assign_instr.copy(expression=substitute(assign_instr.expression, subst), + id=assign_instr.id + '_mod') + kernel = kernel.copy(instructions=kernel.instructions + [new_assign_instr]) + kernel = lp.remove_instructions(kernel, set([assign_instr.id])) - # if 'assignment' in instr.tags or isinstance(instr.assignee, prim.Variable): - if 'assignment' in instr.tags: - # Set loop priority - lp.prioritize_loops(kernel, current_iname_order) - elif 'set_zero' in instr.tags: - # Duplicate inames and prioritize loops - duplicate_inames = tuple(i for i in inner_inames) - match = lp.match.Id(instr.id) - kernel = lp.duplicate_inames(kernel, duplicate_inames, match) - # palpo TODO prioritize! else: - # palpo TODO 2D? - # assert reduction_iname is None - - # Duplicate inames and prioritize loops - duplicate_inames = tuple(i for i in inner_inames) - match = lp.match.Id(instr.id) - kernel = lp.duplicate_inames(kernel, duplicate_inames, match) - # palpo TODO prioritize! + current_iname_order = _current_iname_order(instr.within_inames, new_iname_order) + lp.prioritize_loops(kernel, current_iname_order) return kernel @@ -173,13 +261,6 @@ def _reorder_loops_in_tensor_contraction_accum(kernel, iname_permutation): dim_tags = dim_tags + ',vec' # Create shape - def _get_iname_bound(kernel, iname): - # TODO: Not sure if that works in all cases - ldi, = kernel.get_leaf_domain_indices((iname,)) - domain = kernel.domains[ldi] - pwaff = domain.dim_max(0) - bound = lp.symbolic.pw_aff_to_expr(pwaff) - return bound + 1 shape = tuple(_get_iname_bound(kernel, i) for i in inner_inames + vec_inames) # Update temporary_variables of this kernel @@ -345,6 +426,8 @@ def reorder_loops_in_tensor_contraction(kernel, iname_permutation, accum_variabl def tensor_contraction_loop_order_generator(kernel): + yield kernel, ['None'] + dim = world_dimension() identity = range(dim + 1) import itertools @@ -356,9 +439,8 @@ def tensor_contraction_loop_order_generator(kernel): new_kernel = reorder_loops_in_tensor_contraction(kernel, permutation, accum_variable=True) yield new_kernel, ['reorder_loops_in_tensor_contraction_{}_True'.format(permutation)] - # palpo TODO - # new_kernel = reorder_loops_in_tensor_contraction(kernel, permutation, accum_variable=False) - # yield new_kernel, ['reorder_loops_in_tensor_contraction_{}_False'.format(permutation),] + new_kernel = reorder_loops_in_tensor_contraction(kernel, permutation, accum_variable=False) + yield new_kernel, ['reorder_loops_in_tensor_contraction_{}_False'.format(permutation),] def simple_autotuner(kernel_generator, signature): @@ -391,12 +473,12 @@ def sumfact_performance_transformations(kernel, signature): # dim = world_dimension() # if dim == 2: # # assert False - # kernel = reorder_loops_in_tensor_contraction(kernel, (2, 0, 1), True) - # # kernel = reorder_loops_in_tensor_contraction(kernel, (2, 0, 1), False) + # # kernel = reorder_loops_in_tensor_contraction(kernel, (2, 0, 1), True) + # kernel = reorder_loops_in_tensor_contraction(kernel, (2, 0, 1), False) # else: - # kernel = reorder_loops_in_tensor_contraction(kernel, (3, 2, 0, 1), True) + # # kernel = reorder_loops_in_tensor_contraction(kernel, (3, 2, 0, 1), True) # # kernel = reorder_loops_in_tensor_contraction(kernel, (1, 2, 0, 3), True) - # # kernel = reorder_loops_in_tensor_contraction(kernel, (3, 2, 0, 1), False) + # kernel = reorder_loops_in_tensor_contraction(kernel, (3, 2, 0, 1), False) kernel = autotune_tensor_contraction_loop_order(kernel, signature) pass -- GitLab