diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py index fb36e575e1be5656dd9c2979f4dab8f6888851d0..e5c3b0b4dc5011c087b3ea97ad6e67be556949da 100644 --- a/python/dune/perftool/pdelab/localoperator.py +++ b/python/dune/perftool/pdelab/localoperator.py @@ -489,8 +489,7 @@ def extract_kernel_from_cache(tag): # Preprocess some instruction! from dune.perftool.sumfact.sumfact import expand_sumfact_kernels, filter_sumfact_instructions instructions = [i for i in retrieve_cache_items("{} and instruction".format(tag))] - for insn in instructions: - expand_sumfact_kernels(insn) + expand_sumfact_kernels(instructions) filter_sumfact_instructions() # Now extract regular loopy kernel components diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py index 691aa1dbc114ed47ac9cebc39b6a8692ded1ad69..eb44a4c7247277ee57264b0ac72758c102580356 100644 --- a/python/dune/perftool/sumfact/sumfact.py +++ b/python/dune/perftool/sumfact/sumfact.py @@ -74,20 +74,21 @@ def find_sumfact(expr): return HasSumfactMapper()(expr) -def expand_sumfact_kernels(insn): - if isinstance(insn, (lp.Assignment, lp.CallInstruction)): - replace = {} - deps = [] - for sumf in find_sumfact(insn.expression): - var, dep = sum_factorization_kernel(sumf.a_matrices, sumf.buffer, sumf.insn_dep, sumf.additional_inames) - replace[sumf] = prim.Variable(var) - deps.append(dep) - - if replace: - built_instruction(insn.copy(expression=substitute(insn.expression, replace), - depends_on=frozenset(*deps) - ) - ) +def expand_sumfact_kernels(insns): + for insn in insns: + if isinstance(insn, (lp.Assignment, lp.CallInstruction)): + replace = {} + deps = [] + for sumf in find_sumfact(insn.expression): + var, dep = sum_factorization_kernel(sumf.a_matrices, sumf.buffer, sumf.insn_dep, sumf.additional_inames) + replace[sumf] = prim.Variable(var) + deps.append(dep) + + if replace: + built_instruction(insn.copy(expression=substitute(insn.expression, replace), + depends_on=frozenset(*deps) + ) + ) def filter_sumfact_instructions():