diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py index 42fd320b4a1b9ceeabeb07a726edfd3300709ee2..96ef00127c5848c00866e99a17018be338fe2bed 100644 --- a/python/dune/perftool/sumfact/realization.py +++ b/python/dune/perftool/sumfact/realization.py @@ -9,6 +9,7 @@ from dune.perftool.generation import (barrier, globalarg, instruction, post_include, + preamble, silenced_warning, temporary_variable, transform, @@ -41,6 +42,11 @@ def realize_sum_factorization_kernel(sf, **kwargs): return _realize_sum_factorization_kernel(sf, **kwargs) +@preamble +def alias_data_array(name, data): + return "auto {} = {}.data();".format(name, data) + + @generator_factory(item_tags=("sumfactkernel",), context_tags=("kernel",), cache_key_generator=lambda s, **kw: s.cache_key) @@ -70,16 +76,10 @@ def _realize_sum_factorization_kernel(sf): insn_dep = insn_dep.union(frozenset({lp.match.Writes("input_{}".format(sf.buffer))})) else: - direct_input_arg = "{}.data()".format(direct_input) - globalarg(direct_input_arg, dtype=np.float64) if sf.input.element_index is None: - direct_input_temp = "{}_access".format(direct_input) + direct_input_arg = "{}_access".format(direct_input) else: - direct_input_temp = "{}_access_comp{}".format(direct_input, sf.input.element_index) - - direct_output = None - if get_option('fastdg') and sf.stage == 3: - direct_output = sf.accumvar + ".data()" + direct_input_arg = "{}_access_comp{}".format(direct_input, sf.input.element_index) # Prepare some dim_tags/shapes for later use ftags = ",".join(["f"] * sf.length) @@ -137,20 +137,19 @@ def _realize_sum_factorization_kernel(sf): input_inames = permute_backward(input_inames, perm) inp_shape = permute_backward(inp_shape, perm) - temporary_variable(direct_input_temp, - dtype=np.float64, - shape=inp_shape, - dim_tags=novec_ftags, - base_storage=direct_input_arg, - offset=_dof_offset(sf.input.element, sf.input.element_index), - managed=True) - silenced_warning("read_no_write({})".format(direct_input_temp)) + globalarg(direct_input_arg, + dtype=np.float64, + shape=inp_shape, + dim_tags=novec_ftags, + offset=_dof_offset(sf.input.element, sf.input.element_index), + ) + alias_data_array(direct_input_arg, direct_input) if matrix.vectorized: input_summand = prim.Call(ExplicitVCLCast(np.float64, vector_width=sf.vector_width), - (prim.Subscript(prim.Variable(direct_input_temp), + (prim.Subscript(prim.Variable(direct_input_arg), input_inames),)) else: - input_summand = prim.Subscript(prim.Variable(direct_input_temp), + input_summand = prim.Subscript(prim.Variable(direct_input_arg), input_inames + vec_iname) else: # If we did permute the order of a matrices above we also @@ -208,17 +207,37 @@ def _realize_sum_factorization_kernel(sf): # In case of direct output we directly accumulate the result # of the Sumfactorization into some global data structure. - if l == len(matrix_sequence) - 1 and direct_output is not None: + if l == len(matrix_sequence) - 1 and get_option('fastdg') and sf.stage == 3: ft = get_global_context_value("form_type") + if sf.test_element_index is None: + direct_output = "{}_access".format(sf.accumvar) + else: + direct_output = "{}_access_comp{}".format(sf.accumvar, sf.test_element_index) if ft == 'residual' or ft == 'jacobian_apply': - globalarg(direct_output, dtype=np.float64, shape=output_shape, dim_tags=novec_ftags) + globalarg(direct_output, + dtype=np.float64, + shape=output_shape, + dim_tags=novec_ftags, + offset=_dof_offset(sf.test_element, sf.test_element_index), + ) + alias_data_array(direct_output, sf.accumvar) + assignee = prim.Subscript(prim.Variable(direct_output), output_inames) else: assert ft == 'jacobian' + + direct_output = "{}x{}".format(direct_output, sf.trial_element_index) + rowsize = sum(tuple(s for s in _local_sizes(sf.trial_element))) + from pytools import product + manual_strides = tuple("stride:{}".format(rowsize * product(output_shape[:i])) for i in range(sf.length)) + dim_tags = "{},{}".format(novec_ftags, ",".join(manual_strides)) globalarg(direct_output, dtype=np.float64, shape=output_shape + output_shape, - dim_tags=novec_ftags + "," + novec_ftags) + offset=rowsize * _dof_offset(sf.test_element, sf.test_element_index) + _dof_offset(sf.trial_element, sf.trial_element_index), + dim_tags=dim_tags, + ) + alias_data_array(direct_output, sf.accumvar) # TODO: It is at least questionnable, whether using the *order* of the inames in here # for indexing is a good idea. Then again, it is hard to find an alternative. _ansatz_inames = tuple(prim.Variable(i) for i in sf.within_inames)