diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py index 216ce95c0f7a06ab99cedd7210f01b740b69999b..4ef8cea937be9db092ea4dfb08f973cab0216753 100644 --- a/python/dune/perftool/sumfact/sumfact.py +++ b/python/dune/perftool/sumfact/sumfact.py @@ -339,12 +339,12 @@ def sum_factorization_kernel(a_matrices, restriction=0, direct_input=None, ): - """ - Calculate a sum factorization matrix product. + """Calculate a sum factorization tensor product. Y = A_{d-1}*...*A_0*X - where X is the input matrix and Y is the output variable. + where X is the input tensor and Y is the output variable. This is + done using matrices and reinterpreting the data structures. Arguments: ---------- @@ -353,10 +353,19 @@ def sum_factorization_kernel(a_matrices, Order of application is from 0 up. buf: A string identifying the flip flop buffer in use for intermediate results. The memory is expected to be - pre-initialized with the input. - insn_dep: an instruction ID that the first issued instruction + pre-initialized with the input or you have to provide + direct_input (FastDGGridOperator). + insn_dep: An instruction ID that the first issued instruction should depend upon. All following ones will depend on each other. + additional_inames: Instructions will be executed within those + inames (needed for stage 3 in jacobians). + preferred_position: Will be used in the dry run to order kernels + when doing vectorization e.g. (dx u,dy u,dz u, u). + outshape: Shape of the output. + restriction: Restriction for faces values. + direct_input: Global data structure containing input for + sumfactorization (e.g. when using FastDGGridOperator). """ if get_global_context_value("dry_run", False): return SumfactKernel(a_matrices, buf, stage, preferred_position, restriction), frozenset() @@ -369,6 +378,7 @@ def sum_factorization_kernel(a_matrices, ctags = ctags + ",vec" vec_shape = (4,) + # Measure times and count operations in c++ code if get_option("instrumentation_level") >= 4: timer_name = assembler_routine_name() + '_kernel' + '_stage{}'.format(stage) post_include('HP_DECLARE_TIMER({});'.format(timer_name), filetag='operatorfile') @@ -377,10 +387,12 @@ def sum_factorization_kernel(a_matrices, depends_on=insn_dep, within_inames=additional_inames)}) + # Put a barrier before the sumfactorization kernel insn_dep = frozenset({barrier(depends_on=insn_dep, within_inames=additional_inames, )}) + # Product of all matrices for l, a_matrix in enumerate(a_matrices): # Compute the correct shapes of in- and output matrices of this matrix-matrix multiplication # and get inames that realize the product. @@ -404,16 +416,17 @@ def sum_factorization_kernel(a_matrices, else: k_expr = 0 - # Setup the input of the sum factorization kernel. In the first matrix multiplication - # this can be taken from + # Setup the input of the sum factorization kernel. In the + # first matrix multiplication this can be taken from # * an input temporary (default) # * a global data structure (if FastDGGridOperator is in use) # * a value from a global data structure, broadcasted to a vector type (vectorized + FastDGGridOperator) if l == 0 and direct_input is not None: globalarg(direct_input, dtype=np.float64, shape=inp_shape) if a_matrix.vectorized: - input_summand = prim.Call(prim.Variable("Vec4d"), (prim.Subscript(prim.Variable(direct_input), - (k_expr, prim.Variable(j))),)) + input_summand = prim.Call(prim.Variable("Vec4d"), + (prim.Subscript(prim.Variable(direct_input), + (k_expr, prim.Variable(j))),)) else: input_summand = prim.Subscript(prim.Variable(direct_input), (k_expr, prim.Variable(j)) + vec_iname) @@ -440,11 +453,11 @@ def sum_factorization_kernel(a_matrices, dim_tags=ctags) # Write the matrix-matrix multiplication expression - matprod = Product((prim.Subscript(prim.Variable(a_matrix.name), (prim.Variable(i), k_expr) + vec_iname), - input_summand - )) + matprod = Product((prim.Subscript(prim.Variable(a_matrix.name), + (prim.Variable(i), k_expr) + vec_iname), + input_summand)) + # ... which may be a reduction, if k>0 if a_matrix.cols != 1: - # ... which may be a reduction, if k>0 matprod = lp.Reduction("sum", k, matprod) # Issue the reduction instruction that implements the multiplication @@ -458,6 +471,7 @@ def sum_factorization_kernel(a_matrices, ) }) + # Measure times and count operations in c++ code if get_option("instrumentation_level") >= 4: insn_dep = frozenset({instruction(code="HP_TIMER_STOP({});".format(timer_name), depends_on=insn_dep,