From 69bd3edede9f8065fbe49234b078fbd1c13e3ff1 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Thu, 7 Dec 2017 14:02:44 +0100
Subject: [PATCH] Make sure that all groups are conflicting to prevent shuffled
 kernel realizations

---
 .../dune/perftool/loopy/transformations/disjointgroups.py   | 6 ++++++
 python/dune/perftool/pdelab/localoperator.py                | 3 +++
 python/dune/perftool/sumfact/basis.py                       | 2 +-
 python/dune/perftool/sumfact/realization.py                 | 5 +----
 python/dune/perftool/sumfact/vectorization.py               | 5 -----
 5 files changed, 11 insertions(+), 10 deletions(-)
 create mode 100644 python/dune/perftool/loopy/transformations/disjointgroups.py

diff --git a/python/dune/perftool/loopy/transformations/disjointgroups.py b/python/dune/perftool/loopy/transformations/disjointgroups.py
new file mode 100644
index 00000000..2f0a64f0
--- /dev/null
+++ b/python/dune/perftool/loopy/transformations/disjointgroups.py
@@ -0,0 +1,6 @@
+""" Helper transformation: let each instruction conflict with every group it is not a member of """
+
+
+def make_groups_conflicting(knl):
+    all_groups = frozenset(name for insn in knl.instructions for name in insn.groups)
+    return knl.copy(instructions=[insn.copy(conflicts_with_groups=all_groups - insn.groups) for insn in knl.instructions])
diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py
index 900ce617..23c92da1 100644
--- a/python/dune/perftool/pdelab/localoperator.py
+++ b/python/dune/perftool/pdelab/localoperator.py
@@ -529,6 +529,9 @@ def extract_kernel_from_cache(tag, wrap_in_cgen=True):
     from loopy import make_reduction_inames_unique
     kernel = make_reduction_inames_unique(kernel)
 
+    from dune.perftool.loopy.transformations.disjointgroups import make_groups_conflicting
+    kernel = make_groups_conflicting(kernel)
+
     # Apply the transformations that were gathered during tree traversals
     for trafo in transformations:
         kernel = trafo[0](kernel, *trafo[1])
diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py
index 889f464f..46563ee1 100644
--- a/python/dune/perftool/sumfact/basis.py
+++ b/python/dune/perftool/sumfact/basis.py
@@ -65,7 +65,7 @@ class LFSSumfactKernelInput(SumfactKernelInputBase, ImmutableRecord):
                                  )
 
     def __str__(self):
-        return "{}".format(self.coeff_func(self.restriction))
+        return "{}_{}".format(self.coeff_func(self.restriction), self.element_index)
 
     def realize(self, sf, index, insn_dep):
         lfs = name_lfs(self.element, self.restriction, self.element_index)
diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py
index 13acc6ac..891bbe23 100644
--- a/python/dune/perftool/sumfact/realization.py
+++ b/python/dune/perftool/sumfact/realization.py
@@ -26,9 +26,7 @@ from dune.perftool.sumfact.permutation import (sumfact_permutation_strategy,
                                                permute_backward,
                                                permute_forward,
                                                )
-from dune.perftool.sumfact.vectorization import (attach_vectorization_info,
-                                                 get_all_sumfact_nodes,
-                                                 )
+from dune.perftool.sumfact.vectorization import attach_vectorization_info
 from dune.perftool.sumfact.accumulation import sumfact_iname
 from dune.perftool.loopy.vcl import ExplicitVCLCast
 
@@ -281,7 +279,6 @@ def _realize_sum_factorization_kernel(sf):
                                           tags=frozenset({tag}),
                                           predicates=sf.predicates,
                                           groups=frozenset({sf.group_name}),
-                                          conflicts_with_groups=frozenset([s.group_name for s in get_all_sumfact_nodes()]) - frozenset({sf.group_name}),
                                           )
                               })
 
diff --git a/python/dune/perftool/sumfact/vectorization.py b/python/dune/perftool/sumfact/vectorization.py
index 14af6638..2ba21f52 100644
--- a/python/dune/perftool/sumfact/vectorization.py
+++ b/python/dune/perftool/sumfact/vectorization.py
@@ -39,11 +39,6 @@ def _cache_vectorization_info(old, new):
 _collect_sumfact_nodes = generator_factory(item_tags=("sumfactnodes", "dryrundata"), context_tags="kernel", no_deco=True)
 
 
-def get_all_sumfact_nodes():
-    from dune.perftool.generation import retrieve_cache_items
-    return [i for i in retrieve_cache_items("kernel_default and sumfactnodes")]
-
-
 def attach_vectorization_info(sf):
     assert isinstance(sf, SumfactKernel)
     if get_global_context_value("dry_run"):
-- 
GitLab