diff --git a/python/dune/perftool/generation/cache.py b/python/dune/perftool/generation/cache.py index 58b67988800a8be0cd2200ce340862b81b150764..e4c3fb592dd3eb4e32a51e44fc377d49e969aa10 100644 --- a/python/dune/perftool/generation/cache.py +++ b/python/dune/perftool/generation/cache.py @@ -119,6 +119,9 @@ class _RegisteredFunction(object): # Return the result for immediate usage return self._get_content(cache_key) + def remove_by_value(self, val): + self._memoize_cache = {k:v for k, v in self._memoize_cache.items() if v != val} + def generator_factory(**factory_kwargs): """ A function decorator factory diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py index c1455f35f71a8cb82103df1662870cc39b4992a3..8610952bbcf299f948b50f5227fde362f4601119 100644 --- a/python/dune/perftool/pdelab/localoperator.py +++ b/python/dune/perftool/pdelab/localoperator.py @@ -488,10 +488,8 @@ def generate_kernel(integrals): def extract_kernel_from_cache(tag): # Preprocess some instruction! - from dune.perftool.sumfact.sumfact import expand_sumfact_kernels, filter_sumfact_instructions - instructions = [i for i in retrieve_cache_items("{} and instruction".format(tag))] - expand_sumfact_kernels(instructions) - filter_sumfact_instructions() + from dune.perftool.sumfact.sumfact import expand_sumfact_kernels + expand_sumfact_kernels(tag) # Now extract regular loopy kernel components from dune.perftool.loopy.target import DuneTarget diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py index e3042690f55007928716e0baeb806fa4e0ba9030..c6bc625490cb24361ed198a84da3c978eee652b2 100644 --- a/python/dune/perftool/sumfact/sumfact.py +++ b/python/dune/perftool/sumfact/sumfact.py @@ -16,6 +16,7 @@ from dune.perftool.generation import (backend, globalarg, iname, instruction, + retrieve_cache_items, silenced_warning, temporary_variable, transform, @@ -139,6 +140,12 @@ def default_resolution(insns): depends_on=frozenset(*deps) ) ) + if isinstance(insn, lp.Assignment): + from dune.perftool.generation.loopy import expr_instruction_impl + expr_instruction_impl.remove_by_value(insn) + if isinstance(insn, lp.CallInstruction): + from dune.perftool.generation.loopy import call_instruction_impl + call_instruction_impl.remove_by_value(insn) def apply_sumfact_grad_vectorization(insns, stage): @@ -153,7 +160,7 @@ def apply_sumfact_grad_vectorization(insns, stage): # Now apply some heuristics when to vectorize... if len(set(sumfact_kernels)) < 3: - default_resolution(insns) + return else: # Vectorize!!! sumfact_kernels = sorted(sumfact_kernels, key=lambda s: s.preferred_interleaving_position) @@ -167,6 +174,7 @@ def apply_sumfact_grad_vectorization(insns, stage): # Maybe initialize the input buffer if sumfact_kernels[0].setup_method: + assert stage == 1 shape = product(mat.cols for mat in sumfact_kernels[0].a_matrices) shape = (shape, 4) initialize_buffer(buffer, base_storage_size=4 * shape[0]) @@ -178,6 +186,7 @@ def apply_sumfact_grad_vectorization(insns, stage): func, args = sumf.setup_method insn_dep = insn_dep.union({func(inp, *args, additional_indices=(i,))}) else: + assert stage == 3 # No setup method defined. We need to make sure the input is correctly setup shape = tuple(mat.cols for mat in sumfact_kernels[0].a_matrices) + (4,) initialize_buffer(buffer, base_storage_size=4 * shape[0]) @@ -189,9 +198,11 @@ def apply_sumfact_grad_vectorization(insns, stage): for insn in insns: if isinstance(insn, lp.Assignment): if get_pymbolic_basename(insn.assignee) == sumf.input_temporary: - built_instruction(insn.copy(assignee=prim.Subscript(prim.Variable(inp), insn.assignee.index + (i,)))) + built_instruction(insn.copy(assignee=prim.Subscript(prim.Variable(inp), insn.assignee.index + (i,)), + id=insn.id + "__{}".format(i))) + insn_dep = insn_dep.union(insn.depends_on) from dune.perftool.generation.loopy import expr_instruction_impl - expr_instruction_impl._memoize_cache = {k: v for k, v in expr_instruction_impl._memoize_cache.items() if v.id != insn.id} + expr_instruction_impl.remove_by_value(insn) # Determine the joined AMatrix large_a_matrices = [] @@ -225,19 +236,19 @@ def apply_sumfact_grad_vectorization(insns, stage): depends_on=dep, )) -def expand_sumfact_kernels(insns): - if get_option("vectorize_grads"): - apply_sumfact_grad_vectorization(insns, 1) - apply_sumfact_grad_vectorization(insns, 3) - else: - default_resolution(insns) + if isinstance(insn, lp.Assignment): + from dune.perftool.generation.loopy import expr_instruction_impl + expr_instruction_impl.remove_by_value(insn) + if isinstance(insn, lp.CallInstruction): + from dune.perftool.generation.loopy import call_instruction_impl + call_instruction_impl.remove_by_value(insn) -def filter_sumfact_instructions(): - """ Remove all instructions that contain a SumfactKernel node """ - from dune.perftool.generation.loopy import expr_instruction_impl, call_instruction_impl - expr_instruction_impl._memoize_cache = {k: v for k, v in expr_instruction_impl._memoize_cache.items() if not find_sumfact(v.expression)} - call_instruction_impl._memoize_cache = {k: v for k, v in call_instruction_impl._memoize_cache.items() if not find_sumfact(v.expression)} +def expand_sumfact_kernels(tag): + if get_option("vectorize_grads"): + apply_sumfact_grad_vectorization([i for i in retrieve_cache_items("{} and instruction".format(tag))], 1) + apply_sumfact_grad_vectorization([i for i in retrieve_cache_items("{} and instruction".format(tag))], 3) + default_resolution(retrieve_cache_items("{} and instruction".format(tag))) @iname @@ -389,7 +400,6 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): (Subscript(result, tuple(Variable(i) for i in inames)),) ) ) - instruction(assignees=(), expression=expr, forced_iname_deps=frozenset(inames + visitor.inames),