Skip to content
Snippets Groups Projects
Commit 92aba8d7 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Correctly tag stage 1 kernels in jacobians

We havent got a nonlinear sumfact test case so far, so this one did not break.
test coming soon.
parent 5887d2a3
No related branches found
No related tags found
No related merge requests found
......@@ -261,6 +261,10 @@ def _realize_sum_factorization_kernel(sf):
else:
assignee = prim.Subscript(prim.Variable(out), output_inames + vec_iname)
tag = "sumfact_stage{}".format(sf.stage)
if sf.stage == 3:
tag = "{}_{}".format(tag, "_".join(sf.within_inames))
# Issue the reduction instruction that implements the multiplication
# at the same time store the instruction ID for the next instruction to depend on
insn_dep = frozenset({instruction(assignee=assignee,
......@@ -268,7 +272,7 @@ def _realize_sum_factorization_kernel(sf):
forced_iname_deps=frozenset([iname for iname in out_inames]).union(frozenset(sf.within_inames)),
forced_iname_deps_is_final=True,
depends_on=insn_dep,
tags=frozenset({"sumfact_stage{}_within{}".format(sf.stage, "_".join(sf.within_inames))}),
tags=frozenset({tag}),
predicates=sf.predicates,
groups=frozenset({sf.group_name}),
conflicts_with_groups=frozenset([s.group_name for s in get_all_sumfact_nodes()]) - frozenset({sf.group_name}),
......@@ -278,7 +282,7 @@ def _realize_sum_factorization_kernel(sf):
# Measure times and count operations in c++ code
if get_option("instrumentation_level") >= 4:
stop_insn = frozenset({instruction(code="HP_TIMER_STOP({});".format(timer_name),
depends_on=frozenset({lp.match.Tagged("sumfact_stage{}_within{}".format(sf.stage, "_".join(sf.within_inames)))}),
depends_on=frozenset({tag}),
within_inames=frozenset(sf.within_inames))})
if sf.stage == 1:
qp_timer_name = assembler_routine_name() + '_kernel' + '_quadratureloop'
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment