diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py index 78e04406645c1beac7beea90e8d74764093db598..6b4b42e6402c8a7265820150ca729a6da7bd8eb6 100644 --- a/python/dune/perftool/sumfact/sumfact.py +++ b/python/dune/perftool/sumfact/sumfact.py @@ -64,14 +64,16 @@ def start_sumfactorization(element, container, restriction, component): lfs = name_lfs(element, restriction, component) coeff = pymbolic_coefficient(container, lfs, basisiname) assignee = Subscript(Variable(inp), (Variable(basisiname),)) - instruction(assignee = assignee, - expression = coeff, - ) + from dune.perftool.loopy.stages import stage_insn + insn_dep = instruction(assignee = assignee, + expression = coeff, + depends_on = frozenset({stage_insn(0)}), + ) - return sum_factorization_kernel(a_matrices, inp, "buffer") + return sum_factorization_kernel(a_matrices, inp, "buffer", insn_dep) -def sum_factorization_kernel(a_matrices, inp, buffer, stage=0): +def sum_factorization_kernel(a_matrices, inp, buffer, insn_dep): """ Calculate a sum factorization matrix product. @@ -88,10 +90,6 @@ def sum_factorization_kernel(a_matrices, inp, buffer, stage=0): buffer: A string identifying the flip flop buffer in use for intermediate results. """ - # Get the stage instruction - from dune.perftool.loopy.stages import stage_insn - insn_dep = stage_insn(stage) - for l, a_matrix in enumerate(a_matrices): # Get a temporary that interprets the base storage of the input # as a column-major matrix. In later iteration of the amatrix loop