diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py index bda7bd1328d457a4dfc0b3be054043335d150063..7ca626b37d93b4f603e4465b9ed3561998894997 100644 --- a/python/dune/perftool/sumfact/realization.py +++ b/python/dune/perftool/sumfact/realization.py @@ -26,6 +26,9 @@ from dune.perftool.sumfact.permutation import (sumfact_permutation_strategy, permute_backward, permute_forward, ) +from dune.perftool.sumfact.vectorization import (attach_vectorization_info, + get_all_sumfact_nodes, + ) from dune.perftool.sumfact.accumulation import sumfact_iname from dune.perftool.loopy.vcl import ExplicitVCLCast @@ -267,6 +270,8 @@ def _realize_sum_factorization_kernel(sf): depends_on=insn_dep, tags=frozenset({"sumfact_stage{}_within{}".format(sf.stage, "_".join(sf.within_inames))}), predicates=sf.predicates, + groups=frozenset({sf.group_name}), + conflicts_with_groups=frozenset([s.group_name for s in get_all_sumfact_nodes()]) - frozenset({sf.group_name}), ) }) diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py index fb89d95f4c3a600a45f94940023028ea371b4ccb..a5c4eb0630227769821be5382f156b2a03b2760d 100644 --- a/python/dune/perftool/sumfact/symbolic.py +++ b/python/dune/perftool/sumfact/symbolic.py @@ -170,6 +170,10 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): """ return (self.input, self.restriction, self.accumvar, self.trial_element_index) + @property + def group_name(self): + return "sfgroup_{}_{}_{}_{}".format(self.input, self.restriction, self.accumvar, self.trial_element_index) + # # Some convenience methods to extract information about the sum factorization kernel # @@ -451,6 +455,10 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) def input_key(self): return tuple(k.input_key for k in self.kernels) + @property + def group_name(self): + return "_".join(k.group_name for k in self.kernels) + @property def length(self): return self.kernels[0].length diff --git a/python/dune/perftool/sumfact/vectorization.py b/python/dune/perftool/sumfact/vectorization.py index 3a0447945be53aea13c3f1188a78241e883dede3..b2dca33d0c4bfcf82414e6ec58e0b8c66d1dfa42 100644 --- a/python/dune/perftool/sumfact/vectorization.py +++ b/python/dune/perftool/sumfact/vectorization.py @@ -11,6 +11,7 @@ from dune.perftool.generation import (generator_factory, from dune.perftool.pdelab.restriction import (Restriction, restricted_name, ) +from dune.perftool.sumfact.tabulation import BasisTabulationMatrixArray from dune.perftool.error import PerftoolError from dune.perftool.options import get_option @@ -28,6 +29,11 @@ def _cache_vectorization_info(old, new): _collect_sumfact_nodes = generator_factory(item_tags=("sumfactnodes", "dryrundata"), context_tags="kernel", no_deco=True) +def get_all_sumfact_nodes(): + from dune.perftool.generation import retrieve_cache_items + return [i for i in retrieve_cache_items("kernel_default and sumfactnodes")] + + def attach_vectorization_info(sf): assert isinstance(sf, SumfactKernel) if get_global_context_value("dry_run"):