From 5b20e6c3983cf9608bd104810d98acf99dc6307b Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Fri, 1 Sep 2017 15:46:34 +0200 Subject: [PATCH] Make sure that sum factorization kernels are not interleaved By using loopy's group mechanism. Each sum factorization kernel defines a group that conflicts with all other sum factorization groups. Conflicts: python/dune/perftool/sumfact/realization.py python/dune/perftool/sumfact/vectorization.py --- python/dune/perftool/sumfact/realization.py | 5 +++++ python/dune/perftool/sumfact/symbolic.py | 8 ++++++++ python/dune/perftool/sumfact/vectorization.py | 6 ++++++ 3 files changed, 19 insertions(+) diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py index bda7bd13..7ca626b3 100644 --- a/python/dune/perftool/sumfact/realization.py +++ b/python/dune/perftool/sumfact/realization.py @@ -26,6 +26,9 @@ from dune.perftool.sumfact.permutation import (sumfact_permutation_strategy, permute_backward, permute_forward, ) +from dune.perftool.sumfact.vectorization import (attach_vectorization_info, + get_all_sumfact_nodes, + ) from dune.perftool.sumfact.accumulation import sumfact_iname from dune.perftool.loopy.vcl import ExplicitVCLCast @@ -267,6 +270,8 @@ def _realize_sum_factorization_kernel(sf): depends_on=insn_dep, tags=frozenset({"sumfact_stage{}_within{}".format(sf.stage, "_".join(sf.within_inames))}), predicates=sf.predicates, + groups=frozenset({sf.group_name}), + conflicts_with_groups=frozenset([s.group_name for s in get_all_sumfact_nodes()]) - frozenset({sf.group_name}), ) }) diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py index fb89d95f..a5c4eb06 100644 --- a/python/dune/perftool/sumfact/symbolic.py +++ b/python/dune/perftool/sumfact/symbolic.py @@ -170,6 +170,10 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): """ return (self.input, self.restriction, self.accumvar, 
self.trial_element_index) + @property + def group_name(self): + return "sfgroup_{}_{}_{}_{}".format(self.input, self.restriction, self.accumvar, self.trial_element_index) + # # Some convenience methods to extract information about the sum factorization kernel # @@ -451,6 +455,10 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) def input_key(self): return tuple(k.input_key for k in self.kernels) + @property + def group_name(self): + return "_".join(k.group_name for k in self.kernels) + @property def length(self): return self.kernels[0].length diff --git a/python/dune/perftool/sumfact/vectorization.py b/python/dune/perftool/sumfact/vectorization.py index 3a044794..b2dca33d 100644 --- a/python/dune/perftool/sumfact/vectorization.py +++ b/python/dune/perftool/sumfact/vectorization.py @@ -11,6 +11,7 @@ from dune.perftool.generation import (generator_factory, from dune.perftool.pdelab.restriction import (Restriction, restricted_name, ) +from dune.perftool.sumfact.tabulation import BasisTabulationMatrixArray from dune.perftool.error import PerftoolError from dune.perftool.options import get_option @@ -28,6 +29,11 @@ def _cache_vectorization_info(old, new): _collect_sumfact_nodes = generator_factory(item_tags=("sumfactnodes", "dryrundata"), context_tags="kernel", no_deco=True) +def get_all_sumfact_nodes(): + from dune.perftool.generation import retrieve_cache_items + return [i for i in retrieve_cache_items("kernel_default and sumfactnodes")] + + def attach_vectorization_info(sf): assert isinstance(sf, SumfactKernel) if get_global_context_value("dry_run"): -- GitLab