Skip to content
Snippets Groups Projects
Commit 1b02ba97 authored by René Heß's avatar René Heß
Browse files

Fix another bug in jacobian splitting

parent 4c41c5a7
No related branches found
No related tags found
No related merge requests found
......@@ -159,8 +159,9 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
index = ()
vectag = frozenset()
base_storage_size = product(max(mat.rows, mat.cols) for mat in a_matrices)
temp = initialize_buffer(buf,
base_storage_size=product(max(mat.rows, mat.cols) for mat in a_matrices),
base_storage_size=base_storage_size,
num=2
).get_temporary(shape=shape,
dim_tags=dim_tags,
......@@ -168,9 +169,11 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
)
# Those input fields, that are padded need to be set to zero
# in order to do a horizontal_add lateron
# in order to do a horizontal_add later on
for pad in padding:
instruction(assignee=prim.Subscript(prim.Variable(temp), tuple(Variable(i) for i in quadrature_inames()) + (pad,)),
assignee = prim.Subscript(prim.Variable(temp),
tuple(Variable(i) for i in quadrature_inames()) + (pad,))
instruction(assignee=assignee,
expression=0,
forced_iname_deps=frozenset(quadrature_inames() + visitor.inames),
forced_iname_deps_is_final=True,
......@@ -316,13 +319,15 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
depends_on=insn_dep,
)
# Mark the transformation that moves the quadrature loop inside the trialfunction loops for application
# Mark the transformation that moves the quadrature loop
# inside the trialfunction loops for application
transform(nest_quadrature_loops, visitor.inames)
return insn_dep
# Extract the restrictions on argument-1:
jac_restrictions = frozenset(tuple(ma.restriction for ma in extract_modified_arguments(accterm.term, argnumber=1)))
jac_restrictions = frozenset(tuple(ma.restriction for ma in
extract_modified_arguments(accterm.term, argnumber=1, do_index=True)))
if not jac_restrictions:
jac_restrictions = frozenset({0})
......@@ -461,9 +466,9 @@ def sum_factorization_kernel(a_matrices,
It can make sense to permute the order of directions. If you have
a small m_l (e.g. stage 1 on faces) it is better to do direction l
first. This can be done permuting:
first. This can be done by:
- The order of the A matrices.
- Permuting the order of the A matrices.
- Permuting the input tensor.
- Permuting the output tensor (this assures that the directions of
the output tensor are again ordered from 0 to d-1).
......
......@@ -132,31 +132,18 @@ def split_into_accumulation_terms(expr):
# 5) Further split according to trial function in jacobian terms
if all_jacobian_args:
jac_args = extract_modified_arguments(replace_expr, argnumber=1, do_index=False)
for jac_arg in jac_args:
# 5.1) Cut the expression to this ansatz function
indexed_jac_args = extract_modified_arguments(replace_expr, argnumber=1, do_index=True)
for restriction in (Restriction.NONE, Restriction.POSITIVE, Restriction.NEGATIVE):
replacement = {ma.expr: Zero(shape=ma.expr.ufl_shape,
free_indices=ma.expr.ufl_free_indices,
index_dimensions=ma.expr.ufl_index_dimensions)
for ma in jac_args}
replacement[jac_arg.expr] = jac_arg.expr
jac_expr = replace_expression(replace_expr, replacemap=replacement)
# 5.2) Propagate indexed zeros to simplify expression
jac_expr = zero_propagation(jac_expr)
if ma.restriction != restriction else ma.expr
for ma in indexed_jac_args}
indexed_jac_args = extract_modified_arguments(jac_expr, argnumber=1, do_index=True)
for restriction in (Restriction.NONE, Restriction.POSITIVE, Restriction.NEGATIVE):
replacement = {ma.expr: Zero(shape=ma.expr.ufl_shape,
free_indices=ma.expr.ufl_free_indices,
index_dimensions=ma.expr.ufl_index_dimensions)
if ma.restriction != restriction else ma.expr
for ma in indexed_jac_args}
jac_accum_expr = replace_expression(jac_expr, replacemap=replacement)
jac_expr = replace_expression(replace_expr, replacemap=replacement)
if not isinstance(jac_expr, Zero):
ret.append(AccumulationTerm(jac_accum_expr, test_arg, indexmap, newi))
if not isinstance(jac_expr, Zero):
ret.append(AccumulationTerm(jac_expr, test_arg, indexmap, newi))
else:
if not isinstance(replace_expr, Zero):
ret.append(AccumulationTerm(replace_expr, test_arg, indexmap, newi))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment