diff --git a/python/dune/perftool/loopy/symbolic.py b/python/dune/perftool/loopy/symbolic.py index bfca31d69e4b0be739127a51e0213917e39fa3a4..91dadf16301525cf888c6f9a2e6d225690356a0a 100644 --- a/python/dune/perftool/loopy/symbolic.py +++ b/python/dune/perftool/loopy/symbolic.py @@ -80,7 +80,15 @@ class SumfactKernel(ImmutableRecord, prim.Variable): @property def vectorized(self): - return next(iter(a_matrices)).vectorized + return next(iter(self.a_matrices)).vectorized + + @property + def cache_key(self): + """ The cache key that can be used in generation magic, + Any two sum factorization kernels having the same cache_key + are realized simulatneously! + """ + return hash((self.a_matrices, self.restriction, self.stage, self.buffer)) class FusedMultiplyAdd(prim.Expression): diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py index 9f9e516ca42c47922932402bc799fa47a85cb083..41c9fdcf4d92c29f737dcea7bcb6b84d12ee303f 100644 --- a/python/dune/perftool/sumfact/basis.py +++ b/python/dune/perftool/sumfact/basis.py @@ -120,15 +120,21 @@ def pymbolic_coefficient_gradient(element, restriction, component, coeff_func, v # evaluation of the gradients of basis functions at quadrature # points (stage 1) if not get_global_context_value("dry_run", False): - var, insn_dep = sum_factorization_kernel(a_matrices, - buf, - 1, - preferred_position=i, + from dune.perftool.sumfact.realization import realize_sum_factorization_kernel + var, insn_dep = realize_sum_factorization_kernel(sf, insn_dep=insn_dep, - restriction=restriction, outshape=tuple(mat.rows for mat in a_matrices if mat.face is None), direct_input=direct_input, ) +# var, insn_dep = sum_factorization_kernel(a_matrices, +# buf, +# 1, +# preferred_position=i, +# insn_dep=insn_dep, +# restriction=restriction, +# outshape=tuple(mat.rows for mat in a_matrices if mat.face is None), +# direct_input=direct_input, +# ) else: var = sf @@ -209,15 +215,12 @@ def pymbolic_coefficient(element, restriction, component, coeff_func, visitor): # Add a sum factorization kernel that implements the evaluation of # the basis functions at quadrature points (stage 1) if not get_global_context_value("dry_run", False): - var, _ = sum_factorization_kernel(a_matrices, - buf, - 1, - preferred_position=None, - insn_dep=frozenset({Writes(inp)}), - outshape=tuple(mat.rows for mat in a_matrices if mat.face is None), - restriction=restriction, - direct_input=direct_input, - ) + from dune.perftool.sumfact.realization import realize_sum_factorization_kernel + var, _ = realize_sum_factorization_kernel(sf, + insn_dep=frozenset({Writes(inp)}), + outshape=tuple(mat.rows for mat in a_matrices if mat.face is None), + direct_input=direct_input, + ) else: var = sf diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py new file mode 100644 index 0000000000000000000000000000000000000000..657fe129567be39eef5e67ba6707e3637517ca2b --- /dev/null +++ b/python/dune/perftool/sumfact/realization.py @@ -0,0 +1,254 @@ +""" +The code that triggers the creation of the necessary code constructs +to realize a sum factorization kernel +""" + +from dune.perftool.generation import (barrier, + dump_accumulate_timer, + generator_factory, + get_global_context_value, + instruction, + post_include, + silenced_warning, + transform, + ) +from dune.perftool.loopy.buffer import (get_buffer_temporary, + switch_base_storage, + ) +from dune.perftool.options import get_option +from dune.perftool.pdelab.signatures import assembler_routine_name +from dune.perftool.sumfact.permutation import (_sf_permutation_strategy, + _permute_backward, + _permute_forward, + ) +from dune.perftool.sumfact.vectorization import attach_vectorization_info +from dune.perftool.sumfact.sumfact import sumfact_iname + +import loopy as lp +import pymbolic.primitives as prim + + +@generator_factory(item_tags=("sumfactkernel",), + context_tags=("kernel",), + cache_key_generator=lambda s, **kw: s.cache_key) +def realize_sum_factorization_kernel(sf, insn_dep=frozenset(), outshape=None, direct_input=None, direct_output=None): + # Unify the insn_dep parameter to be a frozenset + if isinstance(insn_dep, str): + insn_dep = frozenset({insn_dep}) + assert isinstance(insn_dep, frozenset) + + if get_global_context_value("dry_run", False): + # During the dry run, we just return the kernel as passed into this + # function. After the dry run, it can be used to attach information + # about vectorization. + return sf, insn_dep +# else: +# # This is the second run: Retrieve the vectorization information +# # attached in dune.perftool.sumfact.vectorization +# sf = attach_vectorization_info(sf) + + # Prepare some dim_tags/shapes for later use + ftags = ",".join(["f"] * sf.length) + novec_ftags = ftags + ctags = ",".join(["c"] * sf.length) + vec_shape = () + if sf.vectorized: + ftags = ftags + ",vec" + ctags = ctags + ",vec" + vec_shape = (4,) + + # Measure times and count operations in c++ code + if get_option("instrumentation_level") >= 4: + timer_name = assembler_routine_name() + '_kernel' + '_stage{}'.format(sf.stage) + post_include('HP_DECLARE_TIMER({});'.format(timer_name), filetag='operatorfile') + dump_accumulate_timer(timer_name) + insn_dep = frozenset({instruction(code="HP_TIMER_START({});".format(timer_name), + depends_on=insn_dep, + within_inames=sf.within_inames)}) + + # Put a barrier before the sumfactorization kernel + insn_dep = frozenset({barrier(depends_on=insn_dep, + within_inames=sf.within_inames, + )}) + + # Decide in which order we want to process directions in the + # sumfactorization. A clever ordering can lead to a reduced + # complexity. This will e.g. happen at faces where we only have + # one quadratue point m_l=1 if l is the normal direction of the + # face. + # + # Rule of thumb: small m's early and large n's late. + perm = _sf_permutation_strategy(sf.a_matrices, sf.stage) + + # Permute a_matrices + a_matrices = _permute_forward(sf.a_matrices, perm) + + # Product of all matrices + for l, a_matrix in enumerate(a_matrices): + # Compute the correct shapes of in- and output matrices of this matrix-matrix multiplication + # and get inames that realize the product. + inp_shape = (a_matrix.cols,) + tuple(mat.cols for mat in a_matrices[l + 1:]) + tuple(mat.rows for mat in a_matrices[:l]) + out_shape = (a_matrix.rows,) + tuple(mat.cols for mat in a_matrices[l + 1:]) + tuple(mat.rows for mat in a_matrices[:l]) + out_inames = tuple(sumfact_iname(length, "out_inames_" + str(k)) for k, length in enumerate(out_shape)) + vec_iname = () + if a_matrix.vectorized: + iname = sumfact_iname(4, "vec") + vec_iname = (prim.Variable(iname),) + transform(lp.tag_inames, [(iname, "vec")]) + + # A trivial reduction is implemented as a product, otherwise we run into + # a code generation corner case producing way too complicated code. This + # could be fixed upstream, but the loopy code realizing reductions is not + # trivial and the priority is kind of low. + if a_matrix.cols != 1: + k = sumfact_iname(a_matrix.cols, "red") + k_expr = prim.Variable(k) + else: + k_expr = 0 + + # Setup the input of the sum factorization kernel. In the + # first matrix multiplication this can be taken from + # * an input temporary (default) + # * a global data structure (if FastDGGridOperator is in use) + # * a value from a global data structure, broadcasted to a vector type (vectorized + FastDGGridOperator) + input_inames = (k_expr,) + tuple(prim.Variable(j) for j in out_inames[1:]) + if l == 0 and direct_input is not None: + # See comment below + input_inames = _permute_backward(input_inames, perm) + inp_shape = _permute_backward(inp_shape, perm) + + globalarg(direct_input, dtype=np.float64, shape=inp_shape, dim_tags=novec_ftags) + if a_matrix.vectorized: + input_summand = prim.Call(prim.Variable("Vec4d"), + (prim.Subscript(prim.Variable(direct_input), + input_inames),)) + else: + input_summand = prim.Subscript(prim.Variable(direct_input), + input_inames + vec_iname) + else: + # If we did permute the order of a matrices above we also + # permuted the order of out_inames. Unfortunately the + # order of our input is from 0 to d-1. This means we need + # to permute _back_ to get the right coefficients. + if l == 0: + inp_shape = _permute_backward(inp_shape, perm) + input_inames = _permute_backward(input_inames, perm) + + # Get a temporary that interprets the base storage of the input + # as a column-major matrix. In later iteration of the amatrix loop + # this reinterprets the output of the previous iteration. + inp = get_buffer_temporary(sf.buffer, + shape=inp_shape + vec_shape, + dim_tags=ftags) + + # The input temporary will only be read from, so we need to silence the loopy warning + silenced_warning('read_no_write({})'.format(inp)) + + input_summand = prim.Subscript(prim.Variable(inp), + input_inames + vec_iname) + + switch_base_storage(sf.buffer) + + # Get a temporary that interprets the base storage of the output. + # + # Note: In this step the reordering of the fastest directions + # is happening. The new direction (out_inames[0]) and the + # corresponding shape (out_shape[0]) goes to the end (slowest + # direction) and everything stays column major (ftags->fortran + # style). + # + # If we are in the last step we reverse the permutation. + output_shape = tuple(out_shape[1:]) + (out_shape[0],) + if l == len(a_matrices) - 1: + output_shape = _permute_backward(output_shape, perm) + out = get_buffer_temporary(sf.buffer, + shape=output_shape + vec_shape, + dim_tags=ftags) + + # Write the matrix-matrix multiplication expression + matprod = prim.Product((prim.Subscript(prim.Variable(a_matrix.name), + (prim.Variable(out_inames[0]), k_expr) + vec_iname), + input_summand)) + + # ... which may be a reduction, if k>0 + if a_matrix.cols != 1: + matprod = lp.Reduction("sum", k, matprod) + + # Here we also move the new direction (out_inames[0]) to the + # end and reverse permutation + output_inames = tuple(prim.Variable(i) for i in out_inames[1:]) + (prim.Variable(out_inames[0]),) + if l == len(a_matrices) - 1: + output_inames = _permute_backward(output_inames, perm) + + # In case of direct output we directly accumulate the result + # of the Sumfactorization into some global data structure. + if l == len(a_matrices) - 1 and direct_output is not None: + ft = get_global_context_value("form_type") + if ft == 'residual' or ft == 'jacobian_apply': + globalarg(direct_output, dtype=np.float64, shape=output_shape, dim_tags=novec_ftags) + assignee = prim.Subscript(prim.Variable(direct_output), output_inames) + else: + assert ft == 'jacobian' + globalarg(direct_output, + dtype=np.float64, + shape=output_shape + output_shape, + dim_tags=novec_ftags + "," + novec_ftags) + # TODO the next line should get its inames from + # elsewhere. This is *NOT* robust (but works right + # now) + _ansatz_inames = tuple(Variable(sf.within_inames[i]) for i in range(world_dimension())) + assignee = prim.Subscript(prim.Variable(direct_output), _ansatz_inames + output_inames) + + # In case of vectorization we need to apply a horizontal add + if a_matrix.vectorized: + matprod = prim.Call(prim.Variable("horizontal_add"), + (matprod,)) + + # We need to accumulate + matprod = prim.Sum((assignee, matprod)) + else: + assignee = prim.Subscript(prim.Variable(out), output_inames + vec_iname) + + # Issue the reduction instruction that implements the multiplication + # at the same time store the instruction ID for the next instruction to depend on + insn_dep = frozenset({instruction(assignee=assignee, + expression=matprod, + forced_iname_deps=frozenset([iname for iname in out_inames]).union(sf.within_inames), + forced_iname_deps_is_final=True, + depends_on=insn_dep, + ) + }) + + # Measure times and count operations in c++ code + if get_option("instrumentation_level") >= 4: + insn_dep = frozenset({instruction(code="HP_TIMER_STOP({});".format(timer_name), + depends_on=insn_dep, + within_inames=sf.within_inames)}) + if sf.stage == 1: + qp_timer_name = assembler_routine_name() + '_kernel' + '_quadratureloop' + post_include('HP_DECLARE_TIMER({});'.format(timer_name), filetag='operatorfile') + dump_accumulate_timer(timer_name) + insn_dep = instruction(code="HP_TIMER_START({});".format(qp_timer_name), + depends_on=insn_dep) + + if outshape is None: + assert sf.stage == 3 + outshape = tuple(mat.rows for mat in a_matrices) + + dim_tags = ",".join(['f'] * len(outshape)) + + if sf.vectorized: + outshape = outshape + vec_shape + # This is a 'bit' hacky: In stage 3 we need to return something with vectag, in stage 1 not. + if sf.stage == 1: + dim_tags = dim_tags + ",c" + else: + dim_tags = dim_tags + ",vec" + + out = get_buffer_temporary(sf.buffer, + shape=outshape, + dim_tags=dim_tags, + ) + silenced_warning('read_no_write({})'.format(out)) + + return next(iter(a_matrices)).output_to_pymbolic(out), insn_dep diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py index 2500a69d51d48cb8f2013bb20060c0539d7aeaaa..d46bf6b9050d360c41d770385a4ee3982cc8e6a4 100644 --- a/python/dune/perftool/sumfact/sumfact.py +++ b/python/dune/perftool/sumfact/sumfact.py @@ -151,7 +151,8 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): sf = SumfactKernel(a_matrices=a_matrices, restriction=(accterm.argument.restriction, restriction), stage=3, - preferred_position=i if accterm.new_indices else None + preferred_position=i if accterm.new_indices else None, + within_inames=frozenset(visitor.inames), ) # TODO: Move this away! @@ -287,16 +288,10 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): # with the test function (stage 3) pref_pos = i if accterm.new_indices else None if not get_global_context_value("dry_run", False): - result, insn_dep = sum_factorization_kernel(a_matrices, - buf, - 3, + from dune.perftool.sumfact.realization import realize_sum_factorization_kernel + result, insn_dep = realize_sum_factorization_kernel(sf, insn_dep=insn_dep, - additional_inames=frozenset(visitor.inames), - preferred_position=pref_pos, - restriction=(accterm.argument.restriction, - restriction), direct_output=direct_output, - visitor=visitor ) else: result = sf diff --git a/python/dune/perftool/sumfact/vectorization.py b/python/dune/perftool/sumfact/vectorization.py index c7662135cc92161b7317a93c793a1af613023bb1..60dadf8af59b8d608f389c8b70cc4ccce6df19db 100644 --- a/python/dune/perftool/sumfact/vectorization.py +++ b/python/dune/perftool/sumfact/vectorization.py @@ -34,7 +34,7 @@ def attach_vectorization_info(sf): def no_vectorization(sumfacts): for sf in sumfacts: - _cache_vectorization_info(sf, sf) + _cache_vectorization_info(sf, sf.copy(buffer=get_counted_variable("buffer"))) def decide_stage_vectorization_strategy(sumfacts, stage, restriction):