Skip to content
Snippets Groups Projects
Commit f2558ce5 authored by René Heß's avatar René Heß
Browse files

[Bugfix] Adjust sumfactorization to new splitting

parent d5a9cbaa
No related branches found
No related tags found
No related merge requests found
......@@ -129,8 +129,8 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
# If this is a gradient, we find the gradient iname
additional_inames = frozenset({})
if accterm.argument.index:
for i in accterm.argument.index._indices:
if accterm.new_indices is not None:
for i in accterm.new_indices:
if i not in visitor.dimension_indices:
from dune.perftool.pdelab.localoperator import grad_iname
additional_inames = additional_inames.union(frozenset({grad_iname(i, dim)}))
......@@ -138,7 +138,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
def emit_sumfact_kernel(i, restriction, insn_dep):
# Construct the matrix sequence for this sum factorization
a_matrices = construct_amatrix_sequence(transpose=True,
derivative=i if accterm.argument.index else None,
derivative=i if accterm.new_indices else None,
facedir=get_facedir(accterm.argument.restriction),
facemod=get_facemod(accterm.argument.restriction),
)
......@@ -219,7 +219,8 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
within_inames=frozenset(visitor.inames))})
inames = tuple(accum_iname((accterm.argument.restriction, restriction), mat.rows, i) for i, mat in enumerate(a_matrices))
inames = tuple(accum_iname((accterm.argument.restriction, restriction), mat.rows, i)
for i, mat in enumerate(a_matrices))
# Collect the lfs and lfs indices for the accumulate call
test_lfs = determine_accumulation_space(accterm.argument.expr, 0, measure)
......@@ -234,7 +235,8 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
if rank == 2:
# TODO the next line should get its inames from
# elsewhere. This is *NOT* robust (but works right now)
ansatz_lfs.index = flatten_index(tuple(Variable(visitor.inames[i]) for i in range(world_dimension())),
ansatz_lfs.index = flatten_index(tuple(Variable(visitor.inames[i])
for i in range(world_dimension())),
(basis_functions_per_direction(),) * dim,
order="f"
)
......@@ -258,14 +260,15 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
# Add a sum factorization kernel that implements the multiplication
# with the test function (stage 3)
pref_pos = i if accterm.argument.index else None
pref_pos = i if accterm.new_indices else None
result, insn_dep = sum_factorization_kernel(a_matrices,
buf,
3,
insn_dep=insn_dep,
additional_inames=frozenset(visitor.inames),
preferred_position=pref_pos,
restriction=(accterm.argument.restriction, restriction),
restriction=(accterm.argument.restriction,
restriction),
direct_output=direct_output,
visitor=visitor
)
......@@ -325,7 +328,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
insn_dep = None
for restriction in jac_restrictions:
if accterm.argument.index:
if accterm.new_indices:
for i in range(world_dimension()):
insn_dep = emit_sumfact_kernel(i, restriction, insn_dep)
else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment