From 69bd3edede9f8065fbe49234b078fbd1c13e3ff1 Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Thu, 7 Dec 2017 14:02:44 +0100 Subject: [PATCH] Make sure that all groups are conflicting to prevent shuffled kernel realizations --- .../dune/perftool/loopy/transformations/disjointgroups.py | 6 ++++++ python/dune/perftool/pdelab/localoperator.py | 3 +++ python/dune/perftool/sumfact/basis.py | 2 +- python/dune/perftool/sumfact/realization.py | 5 +---- python/dune/perftool/sumfact/vectorization.py | 5 ----- 5 files changed, 11 insertions(+), 10 deletions(-) create mode 100644 python/dune/perftool/loopy/transformations/disjointgroups.py diff --git a/python/dune/perftool/loopy/transformations/disjointgroups.py b/python/dune/perftool/loopy/transformations/disjointgroups.py new file mode 100644 index 00000000..2f0a64f0 --- /dev/null +++ b/python/dune/perftool/loopy/transformations/disjointgroups.py @@ -0,0 +1,6 @@ +""" A helper transformation that makes all groups conflicting """ + + +def make_groups_conflicting(knl): + groups = frozenset().union(*tuple(i.groups for i in knl.instructions)) + return knl.copy(instructions=[i.copy(conflicts_with_groups=groups - i.groups) for i in knl.instructions]) diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py index 900ce617..23c92da1 100644 --- a/python/dune/perftool/pdelab/localoperator.py +++ b/python/dune/perftool/pdelab/localoperator.py @@ -529,6 +529,9 @@ def extract_kernel_from_cache(tag, wrap_in_cgen=True): from loopy import make_reduction_inames_unique kernel = make_reduction_inames_unique(kernel) + from dune.perftool.loopy.transformations.disjointgroups import make_groups_conflicting + kernel = make_groups_conflicting(kernel) + # Apply the transformations that were gathered during tree traversals for trafo in transformations: kernel = trafo[0](kernel, *trafo[1]) diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py index 889f464f..46563ee1 100644 --- a/python/dune/perftool/sumfact/basis.py +++ b/python/dune/perftool/sumfact/basis.py @@ -65,7 +65,7 @@ class LFSSumfactKernelInput(SumfactKernelInputBase, ImmutableRecord): ) def __str__(self): - return "{}".format(self.coeff_func(self.restriction)) + return "{}_{}".format(self.coeff_func(self.restriction), self.element_index) def realize(self, sf, index, insn_dep): lfs = name_lfs(self.element, self.restriction, self.element_index) diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py index 13acc6ac..891bbe23 100644 --- a/python/dune/perftool/sumfact/realization.py +++ b/python/dune/perftool/sumfact/realization.py @@ -26,9 +26,7 @@ from dune.perftool.sumfact.permutation import (sumfact_permutation_strategy, permute_backward, permute_forward, ) -from dune.perftool.sumfact.vectorization import (attach_vectorization_info, - get_all_sumfact_nodes, - ) +from dune.perftool.sumfact.vectorization import attach_vectorization_info from dune.perftool.sumfact.accumulation import sumfact_iname from dune.perftool.loopy.vcl import ExplicitVCLCast @@ -281,7 +279,6 @@ def _realize_sum_factorization_kernel(sf): tags=frozenset({tag}), predicates=sf.predicates, groups=frozenset({sf.group_name}), - conflicts_with_groups=frozenset([s.group_name for s in get_all_sumfact_nodes()]) - frozenset({sf.group_name}), ) }) diff --git a/python/dune/perftool/sumfact/vectorization.py b/python/dune/perftool/sumfact/vectorization.py index 14af6638..2ba21f52 100644 --- a/python/dune/perftool/sumfact/vectorization.py +++ b/python/dune/perftool/sumfact/vectorization.py @@ -39,11 +39,6 @@ def _cache_vectorization_info(old, new): _collect_sumfact_nodes = generator_factory(item_tags=("sumfactnodes", "dryrundata"), context_tags="kernel", no_deco=True) -def get_all_sumfact_nodes(): - from dune.perftool.generation import retrieve_cache_items - return [i for i in retrieve_cache_items("kernel_default and sumfactnodes")] - - def attach_vectorization_info(sf): assert isinstance(sf, SumfactKernel) if get_global_context_value("dry_run"): -- GitLab