Skip to content
Snippets Groups Projects
Commit dc799b48 authored by René Heß's avatar René Heß
Browse files

Accumulation possible for fastdg

parent a5ae37b7
No related branches found
No related tags found
No related merge requests found
......@@ -139,29 +139,29 @@ def pymbolic_trialfunction(element, restriction, component, visitor):
# Get the vectorization info. If this happens during the dry run, we get dummies
from dune.perftool.sumfact.vectorization import get_vectorization_info
a_matrices, buffer, input, index, padding = get_vectorization_info(a_matrices, restriction)
a_matrices, buf, inp, index, padding = get_vectorization_info(a_matrices, restriction)
# Flip-flop buffers for sum factorization
shape = (product(mat.cols for mat in a_matrices),)
if index is not None:
shape = shape + (4,)
initialize_buffer(buffer,
initialize_buffer(buf,
base_storage_size=product(max(mat.rows, mat.cols) for mat in a_matrices),
num=2
).get_temporary(shape=shape,
name=input,
name=inp,
)
# Setup the input!
setup_theta(input, element, restriction, component, index)
setup_theta(inp, element, restriction, component, index)
# Add a sum factorization kernel that implements the evaluation of
# the basis functions at quadrature points (stage 1)
var, _ = sum_factorization_kernel(a_matrices,
buffer,
buf,
1,
preferred_position=None,
insn_dep=frozenset({Writes(input)}),
insn_dep=frozenset({Writes(inp)}),
outshape=tuple(mat.rows for mat in a_matrices if mat.rows != 1),
restriction=restriction,
)
......
......@@ -244,9 +244,11 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
(maybe_wrap_subscript(result, prim.Variable(iname)),),
)
# In the case of FastDGGridOperator we can write directly into the residual/Jacobian
if get_option('fastdg'):
ft = get_global_context_value("form_type")
if ft=='residual':
accum = accum + ".data()"
size = basis_functions_per_direction() ** world_dimension()
globalarg(accum, dtype=np.float64, shape=(size,), managed=False)
assignee = prim.Subscript(prim.Variable(accum), (test_lfs.index,))
......@@ -259,9 +261,9 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
)
else:
assert ft=='jacobian'
# palpo TODO: think about it
accum = accum + ".data()"
size = basis_functions_per_direction() ** world_dimension()
globalarg(accum, dtype=np.float64, shape=(size, size), managed=False)
globalarg(accum, dtype=np.float64, shape=(size, size), managed=True)
assignee = prim.Subscript(prim.Variable(accum), (ansatz_lfs.index, test_lfs.index))
expression = prim.Sum((assignee,result))
instruction(assignee=assignee,
......@@ -270,19 +272,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
forced_iname_deps_is_final=True,
depends_on=insn_dep,
)
# expr = Call(PDELabAccumulationFunction(accum, rank),
# (ansatz_lfs.get_args() +
# test_lfs.get_args() +
# (result,)
# )
# )
# instruction(assignees=(),
# expression=expr,
# forced_iname_deps=frozenset(inames + visitor.inames + vecinames),
# forced_iname_deps_is_final=True,
# depends_on=insn_dep,
# )
# Default: Generate accumulation instructions
else:
expr = Call(PDELabAccumulationFunction(accum, rank),
(ansatz_lfs.get_args() +
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment