diff --git a/python/dune/perftool/generation/cache.py b/python/dune/perftool/generation/cache.py index 2af48bd16417a7f3c62dd45df7c0473ad6b3ce86..9f7aba9be5c24e06a7f16e98624ba04b08eefa3e 100644 --- a/python/dune/perftool/generation/cache.py +++ b/python/dune/perftool/generation/cache.py @@ -110,7 +110,11 @@ class _RegisteredFunction(object): return self._get_content(cache_key) def __call__(self, *args, **kwargs): - with global_context(**self.kwargs): + additional_kw = {k: kwargs[k] for k in kwargs if k in self.context_tags} + for k, v in self.kwargs.items(): + additional_kw[k] = v + kwargs = {k: kwargs[k] for k in kwargs if k not in self.context_tags} + with global_context(**additional_kw): return self.call(*args, **kwargs) diff --git a/python/test/dune/perftool/generation/test_cache.py b/python/test/dune/perftool/generation/test_cache.py index fa86a9b5c8c744628acad7f4061213eb8a264b2e..47b9bc7f5e261718aaaa25a6b6a4b3c904d35f83 100644 --- a/python/test/dune/perftool/generation/test_cache.py +++ b/python/test/dune/perftool/generation/test_cache.py @@ -290,3 +290,40 @@ def test_multiple_kernels_3(): assert k4 == "bar" delete_cache_items() + + +def test_multiple_kernels_4(): + gen = generator_factory(item_tags=("tag",), context_tags=("kernel",), no_deco=True) + + with global_context(kernel="k1"): + gen("foo") + + with global_context(kernel="k2"): + gen("bar") + + assert len([i for i in retrieve_cache_items("tag")]) == 2 + + k1, = retrieve_cache_items("k1") + assert k1 == "foo" + + k2, = retrieve_cache_items("k2") + assert k2 == "bar" + + delete_cache_items() + + +def test_multiple_kernels_5(): + gen = generator_factory(item_tags=("tag",), context_tags=("kernel",), no_deco=True) + + gen("foo", kernel="k1") + gen("bar", kernel="k2") + + assert len([i for i in retrieve_cache_items("tag")]) == 2 + + k1, = retrieve_cache_items("k1") + assert k1 == "foo" + + k2, = retrieve_cache_items("k2") + assert k2 == "bar" + + delete_cache_items()