diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py index 82aa38d8982d4b66b753e1f1d07dbaa9e009d341..eee3bb36fe8bbde7c728c65f6e906d4e157e32a5 100644 --- a/python/dune/perftool/sumfact/sumfact.py +++ b/python/dune/perftool/sumfact/sumfact.py @@ -339,12 +339,37 @@ def sum_factorization_kernel(a_matrices, restriction=0, direct_input=None, ): - """Calculate a sum factorization tensor product. + """Create a sum factorization kernel - Y = A_{d-1}*...*A_0*X + Sum factorization can be written as - where X is the input tensor and Y is the output variable. This is - done using matrices and reinterpreting the data structures. + Y = R_{d-1} (A_{d-1} * ... * R_0 (A_0 * X)...) + + with: + - X: Input rank d tensor of dimension n_0 x ... x n_{d-1} + - Y: Output rank d tensor of dimension m_0 x ... x m_{d-1} + - A_l: Values of 1D basis evaluations at quadrature points in l + direction, matrix of dimension m_l x n_l + - R_l: Transformation operator that permutes the underlying data + vector of the rank d tensor in such a way that the fastest + direction becomes the slowest direction + + In the l'th step we have the following setup: + - A_l: Matrix of dimensions m_l x n_l + - X_l: Rank d tensor of dimensions n_l x ... x n_{d-1} x m_0 x ... x m_{l-1} + - R_l: Transformation operator + + Looking at the indices, the following will happen: + X --> [n_l,...,n_{d-1},m_0,...,m_{l-1}] + A_l * X --> [m_l,n_l] * [n_l, ...] = [m_l,n_{l+1},...,n_{d-1},m_0,...,m_{l-1}] + R_l (A_l*X) --> [n_{l+1},...,n_{d-1},m_0,...,m_{l-1}] + + So the multiplication with A_l is a reduction over one index and the + transformation brings the next reduction index into the fastest + position. + + Note: In the code below the transformation step is directly done + in the reduction instruction by adapting the assignee! Arguments: ---------- @@ -366,6 +391,7 @@ def sum_factorization_kernel(a_matrices, restriction: Restriction for faces values. 
direct_input: Global data structure containing input for sumfactorization (e.g. when using FastDGGridOperator). + """ if get_global_context_value("dry_run", False): return SumfactKernel(a_matrices, buf, stage, preferred_position, restriction), frozenset() @@ -396,12 +422,8 @@ def sum_factorization_kernel(a_matrices, for l, a_matrix in enumerate(a_matrices): # Compute the correct shapes of in- and output matrices of this matrix-matrix multiplication # and get inames that realize the product. - # inp_shape = (a_matrix.cols, product(mat.rows for mat in a_matrices[:l]) * product(mat.cols for mat in a_matrices[l + 1:])) - # out_shape = (a_matrix.rows, product(mat.rows for mat in a_matrices[:l]) * product(mat.cols for mat in a_matrices[l + 1:])) inp_shape = (a_matrix.cols,) + tuple(mat.cols for mat in a_matrices[l + 1:]) + tuple(mat.rows for mat in a_matrices[:l]) out_shape = (a_matrix.rows,) + tuple(mat.cols for mat in a_matrices[l + 1:]) + tuple(mat.rows for mat in a_matrices[:l]) - # i = sumfact_iname(out_shape[0], "row") - # j = sumfact_iname(out_shape[1], "col") out_inames = tuple(sumfact_iname(length, "out_inames_" + str(k)) for k, length in enumerate(out_shape)) vec_iname = () if a_matrix.vectorized: @@ -449,8 +471,13 @@ def sum_factorization_kernel(a_matrices, switch_base_storage(buf) - # Get a temporary that interprets the base storage of the output - # as row-major matrix + # Get a temporary that interprets the base storage of the output. + # + # Note: In this step the reordering of the fastest directions + # is happening. The new direction (out_inames[0]) and the + # corresponding shape (out_shape[0]) go to the end (slowest + # direction) and everything stays column major (ftags->fortran + # style). 
out = get_buffer_temporary(buf, shape=tuple(out_shape[1:]) + (out_shape[0],) + vec_shape, dim_tags=ftags) @@ -467,10 +494,11 @@ def sum_factorization_kernel(a_matrices, if a_matrix.cols != 1: matprod = lp.Reduction("sum", k, matprod) + # Here we also move the new direction (out_inames[0]) to the end + assignee = prim.Subscript(prim.Variable(out), tuple(prim.Variable(i) for i in out_inames[1:]) + (prim.Variable(out_inames[0]),) + vec_iname) # Issue the reduction instruction that implements the multiplication # at the same time store the instruction ID for the next instruction to depend on - assignee = prim.Subscript(prim.Variable(out), tuple(prim.Variable(i) for i in out_inames[1:]) + (prim.Variable(out_inames[0]),) + vec_iname) insn_dep = frozenset({instruction(assignee=assignee, expression=matprod, forced_iname_deps=frozenset([iname for iname in out_inames]).union(additional_inames),