Skip to content
Snippets Groups Projects
Commit 940fee2e authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Reimplement FastDG for systems

parent 2b251398
No related branches found
No related tags found
No related merge requests found
......@@ -9,6 +9,7 @@ from dune.perftool.generation import (barrier,
globalarg,
instruction,
post_include,
preamble,
silenced_warning,
temporary_variable,
transform,
......@@ -41,6 +42,11 @@ def realize_sum_factorization_kernel(sf, **kwargs):
return _realize_sum_factorization_kernel(sf, **kwargs)
@preamble
def alias_data_array(name, data):
return "auto {} = {}.data();".format(name, data)
@generator_factory(item_tags=("sumfactkernel",),
context_tags=("kernel",),
cache_key_generator=lambda s, **kw: s.cache_key)
......@@ -70,16 +76,10 @@ def _realize_sum_factorization_kernel(sf):
insn_dep = insn_dep.union(frozenset({lp.match.Writes("input_{}".format(sf.buffer))}))
else:
direct_input_arg = "{}.data()".format(direct_input)
globalarg(direct_input_arg, dtype=np.float64)
if sf.input.element_index is None:
direct_input_temp = "{}_access".format(direct_input)
direct_input_arg = "{}_access".format(direct_input)
else:
direct_input_temp = "{}_access_comp{}".format(direct_input, sf.input.element_index)
direct_output = None
if get_option('fastdg') and sf.stage == 3:
direct_output = sf.accumvar + ".data()"
direct_input_arg = "{}_access_comp{}".format(direct_input, sf.input.element_index)
# Prepare some dim_tags/shapes for later use
ftags = ",".join(["f"] * sf.length)
......@@ -137,20 +137,19 @@ def _realize_sum_factorization_kernel(sf):
input_inames = permute_backward(input_inames, perm)
inp_shape = permute_backward(inp_shape, perm)
temporary_variable(direct_input_temp,
dtype=np.float64,
shape=inp_shape,
dim_tags=novec_ftags,
base_storage=direct_input_arg,
offset=_dof_offset(sf.input.element, sf.input.element_index),
managed=True)
silenced_warning("read_no_write({})".format(direct_input_temp))
globalarg(direct_input_arg,
dtype=np.float64,
shape=inp_shape,
dim_tags=novec_ftags,
offset=_dof_offset(sf.input.element, sf.input.element_index),
)
alias_data_array(direct_input_arg, direct_input)
if matrix.vectorized:
input_summand = prim.Call(ExplicitVCLCast(np.float64, vector_width=sf.vector_width),
(prim.Subscript(prim.Variable(direct_input_temp),
(prim.Subscript(prim.Variable(direct_input_arg),
input_inames),))
else:
input_summand = prim.Subscript(prim.Variable(direct_input_temp),
input_summand = prim.Subscript(prim.Variable(direct_input_arg),
input_inames + vec_iname)
else:
# If we did permute the order of a matrices above we also
......@@ -208,17 +207,37 @@ def _realize_sum_factorization_kernel(sf):
# In case of direct output we directly accumulate the result
# of the Sumfactorization into some global data structure.
if l == len(matrix_sequence) - 1 and direct_output is not None:
if l == len(matrix_sequence) - 1 and get_option('fastdg') and sf.stage == 3:
ft = get_global_context_value("form_type")
if sf.test_element_index is None:
direct_output = "{}_access".format(sf.accumvar)
else:
direct_output = "{}_access_comp{}".format(sf.accumvar, sf.test_element_index)
if ft == 'residual' or ft == 'jacobian_apply':
globalarg(direct_output, dtype=np.float64, shape=output_shape, dim_tags=novec_ftags)
globalarg(direct_output,
dtype=np.float64,
shape=output_shape,
dim_tags=novec_ftags,
offset=_dof_offset(sf.test_element, sf.test_element_index),
)
alias_data_array(direct_output, sf.accumvar)
assignee = prim.Subscript(prim.Variable(direct_output), output_inames)
else:
assert ft == 'jacobian'
direct_output = "{}x{}".format(direct_output, sf.trial_element_index)
rowsize = sum(tuple(s for s in _local_sizes(sf.trial_element)))
from pytools import product
manual_strides = tuple("stride:{}".format(rowsize * product(output_shape[:i])) for i in range(sf.length))
dim_tags = "{},{}".format(novec_ftags, ",".join(manual_strides))
globalarg(direct_output,
dtype=np.float64,
shape=output_shape + output_shape,
dim_tags=novec_ftags + "," + novec_ftags)
offset=rowsize * _dof_offset(sf.test_element, sf.test_element_index) + _dof_offset(sf.trial_element, sf.trial_element_index),
dim_tags=dim_tags,
)
alias_data_array(direct_output, sf.accumvar)
# TODO: It is at least questionnable, whether using the *order* of the inames in here
# for indexing is a good idea. Then again, it is hard to find an alternative.
_ansatz_inames = tuple(prim.Variable(i) for i in sf.within_inames)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment