diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py index 28791c17bdae29dd30d5718cac4cfdb829b535b4..0ee51952d2c49adc09299996e430d6675da62524 100644 --- a/python/dune/perftool/sumfact/sumfact.py +++ b/python/dune/perftool/sumfact/sumfact.py @@ -394,6 +394,20 @@ def _sf_permutation_strategy(a_matrices, stage): return perm +def _permute_forward(t, perm): + tmp = [] + for pos in perm: + tmp.append(t[pos]) + return tuple(tmp) + + +def _permute_backward(t, perm): + tmp = [None]*len(t) + for i, pos in enumerate(perm): + tmp[pos] = t[i] + return tuple(tmp) + + @generator_factory(item_tags=("sumfactkernel",), context_tags=("kernel",), cache_key_generator=lambda a, b, s, **kw: (a, b, s, kw.get("restriction", 0))) def sum_factorization_kernel(a_matrices, buf, @@ -437,6 +451,15 @@ def sum_factorization_kernel(a_matrices, Note: In the code below the transformation step is directly done in the reduction instruction by adapting the assignee! + It can make sense to permute the order of directions. If you have + a small m_l (e.g. stage 1 on faces) it is better to do direction l + first. This can be done permuting: + + - The order of the A matrices. + - Permuting the input tensor. + - Permuting the output tensor (this assures that the directions of + the output tensor are again ordered from 0 to d-1). + Arguments: ---------- a_matrices: An iterable of AMatrix instances @@ -490,38 +513,10 @@ def sum_factorization_kernel(a_matrices, # face. # # Rule of thumb: small m's early and large n's late. - - # palpo TODO - if stage==3 and outshape!=None: - from IPython import embed; embed(); import sys; sys.exit("Error message") - - if stage==1 or stage==3: - # perm = range(len(a_matrices)) - perm = _sf_permutation_strategy(a_matrices, stage) - else: - perm = range(len(a_matrices)) - - # # palpo TODO - # print("## PALPO") - # shape = [(mat.rows,mat.cols) for mat in a_matrices] - # print(shape) + perm = _sf_permutation_strategy(a_matrices, stage) # Permute a_matrices - new_a_matrices = [] - for pos in perm: - new_a_matrices.append(a_matrices[pos]) - a_matrices = tuple(new_a_matrices) - - # # palpo TODO - # shape = [(mat.rows,mat.cols) for mat in a_matrices] - # print(shape) - - # new_a_matrices = [None]*len(a_matrices) - # for i, pos in enumerate(perm): - # new_a_matrices[pos] = a_matrices[i] - # a_matrices = tuple(new_a_matrices) - # shape = [(mat.rows,mat.cols) for mat in a_matrices] - # print(shape) + a_matrices = _permute_forward(a_matrices, perm) # Product of all matrices for l, a_matrix in enumerate(a_matrices): @@ -565,32 +560,18 @@ def sum_factorization_kernel(a_matrices, input_summand = prim.Subscript(prim.Variable(direct_input), palpo + vec_iname) else: + # If we did permute the order of a matrices above we also + # permuted the order of out_inames. Unfortunately the + # order of our input is from 0 to d-1. This means we need + # to permute _back_ to get the right coefficients. + input_inames = (k_expr,) + tuple(prim.Variable(j) for j in out_inames[1:]) + if l == 0: + inp_shape = _permute_backward(inp_shape, perm) + input_inames = _permute_backward(input_inames, perm) + # Get a temporary that interprets the base storage of the input # as a column-major matrix. In later iteration of the amatrix loop # this reinterprets the output of the previous iteration. - palpo = (k_expr,) + tuple(prim.Variable(j) for j in out_inames[1:]) - if l==0: - tmp_perm = [None]*len(inp_shape) - for i, pos in enumerate(perm): - tmp_perm[pos] = inp_shape[i] - inp_shape = tuple(tmp_perm) - - tmp_perm = [None]*len(palpo) - for i, pos in enumerate(perm): - tmp_perm[pos] = palpo[i] - palpo = tuple(tmp_perm) - - # tmp_perm = [] - # for pos in perm: - # tmp_perm.append(inp_shape[pos]) - # inp_shape = tuple(tmp_perm) - - # palpo = (k_expr,) + tuple(prim.Variable(j) for j in out_inames[1:]) - # tmp_perm = [] - # for pos in perm: - # tmp_perm.append(palpo[pos]) - # palpo = tuple(tmp_perm) - inp = get_buffer_temporary(buf, shape=inp_shape + vec_shape, dim_tags=ftags) @@ -599,7 +580,7 @@ def sum_factorization_kernel(a_matrices, silenced_warning('read_no_write({})'.format(inp)) input_summand = prim.Subscript(prim.Variable(inp), - palpo + vec_iname) + input_inames + vec_iname) switch_base_storage(buf) @@ -610,32 +591,16 @@ def sum_factorization_kernel(a_matrices, # corresponding shape (out_shape[0]) goes to the end (slowest # direction) and everything stays column major (ftags->fortran # style). - # if False: + # + # If we are in the last step we reverse the permutation. + output_shape = tuple(out_shape[1:]) + (out_shape[0],) if l == len(a_matrices)-1: - out_shape = tuple(out_shape[1:]) + (out_shape[0],) - tmp_perm = [None]*len(out_shape) - for i, pos in enumerate(perm): - tmp_perm[pos] = out_shape[i] - out_shape = tuple(tmp_perm) - - # tmp_perm = [] - # for pos in perm: - # tmp_perm.append(out_shape[pos]) - # out_shape = tuple(tmp_perm) - - - out = get_buffer_temporary(buf, - shape=out_shape + vec_shape, - dim_tags=ftags) - else: - out = get_buffer_temporary(buf, - shape=tuple(out_shape[1:]) + (out_shape[0],) + vec_shape, - dim_tags=ftags) + output_shape = _permute_backward(output_shape, perm) + out = get_buffer_temporary(buf, + shape=output_shape + vec_shape, + dim_tags=ftags) # Write the matrix-matrix multiplication expression - # matprod = Product((prim.Subscript(prim.Variable(a_matrix.name), - # (prim.Variable(i), k_expr) + vec_iname), - # input_summand)) matprod = Product((prim.Subscript(prim.Variable(a_matrix.name), (prim.Variable(out_inames[0]), k_expr) + vec_iname), input_summand)) @@ -644,24 +609,12 @@ def sum_factorization_kernel(a_matrices, if a_matrix.cols != 1: matprod = lp.Reduction("sum", k, matprod) - # Here we also move the new direction (out_inames[0]) to the end - # if False: + # Here we also move the new direction (out_inames[0]) to the + # end and reverse permutation + output_inames = tuple(prim.Variable(i) for i in out_inames[1:]) + (prim.Variable(out_inames[0]),) if l == len(a_matrices)-1: - palpo = tuple(prim.Variable(i) for i in out_inames[1:]) + (prim.Variable(out_inames[0]),) - tmp_perm = [None]*len(palpo) - for i, pos in enumerate(perm): - tmp_perm[pos] = palpo[i] - palpo = tuple(tmp_perm) - - # palpo = tuple(prim.Variable(i) for i in out_inames[1:]) + (prim.Variable(out_inames[0]),) - # tmp_perm = [] - # for pos in perm: - # tmp_perm.append(palpo[pos]) - # palpo = tuple(tmp_perm) - - assignee = prim.Subscript(prim.Variable(out), palpo + vec_iname) - else: - assignee = prim.Subscript(prim.Variable(out), tuple(prim.Variable(i) for i in out_inames[1:]) + (prim.Variable(out_inames[0]),) + vec_iname) + output_inames = _permute_backward(output_inames, perm) + assignee = prim.Subscript(prim.Variable(out), output_inames + vec_iname) # Issue the reduction instruction that implements the multiplication # at the same time store the instruction ID for the next instruction to depend on