diff --git a/python/dune/perftool/generation/cache.py b/python/dune/perftool/generation/cache.py index 788ba19253c58f1c0d14b46cfd67a3859f35dde6..2af48bd16417a7f3c62dd45df7c0473ad6b3ce86 100644 --- a/python/dune/perftool/generation/cache.py +++ b/python/dune/perftool/generation/cache.py @@ -1,7 +1,9 @@ """ This module provides the memoization infrastructure for code generating functions. """ -from dune.perftool.generation.context import get_global_context_value +from dune.perftool.generation.context import (get_global_context_value, + global_context, + ) from dune.perftool.generation.counter import get_counter # Store a global list of generator functions @@ -54,6 +56,7 @@ class _RegisteredFunction(object): on_store=lambda x: x, item_tags=(), context_tags=(), + **kwargs ): self.func = func self.cache_key_generator = cache_key_generator @@ -61,6 +64,7 @@ class _RegisteredFunction(object): self.on_store = on_store self.item_tags = item_tags self.context_tags = context_tags + self.kwargs = kwargs # Initialize the memoization cache self._memoize_cache = {} @@ -81,7 +85,7 @@ class _RegisteredFunction(object): else: return self._memoize_cache[key] - def __call__(self, *args, **kwargs): + def call(self, *args, **kwargs): # Get the cache key from the given arguments cache_key = self.cache_key_generator(*args, **kwargs) @@ -105,6 +109,10 @@ class _RegisteredFunction(object): # Return the result for immediate usage return self._get_content(cache_key) + def __call__(self, *args, **kwargs): + with global_context(**self.kwargs): + return self.call(*args, **kwargs) + def generator_factory(**factory_kwargs): """ A function decorator factory @@ -172,24 +180,24 @@ cached = generator_factory(item_tags=("default_cached",)) class _ConditionDict(dict): - def __init__(self, tags): - dict.__init__(self) - self.tags = tags - - def __getitem__(self, i): - # If we do not add these special cases the dictionary will return False - # when we execute the following code: - # - # eval ("True", _ConditionDict(v.tags) - # - # But in this case we want to return True! A normal dictionary would not attempt - # to replace "True" if "True" is not a key. The _ConditionDict obviously has no - # such concerns ;). - if i == "True": - return True - if i == "False": - return False - return i in self.tags + def __init__(self, tags): + dict.__init__(self) + self.tags = tags + + def __getitem__(self, i): + # If we do not add these special cases the dictionary will return False + # when we execute the following code: + # + # eval ("True", _ConditionDict(v.tags) + # + # But in this case we want to return True! A normal dictionary would not attempt + # to replace "True" if "True" is not a key. The _ConditionDict obviously has no + # such concerns ;). + if i == "True": + return True + if i == "False": + return False + return i in self.tags def _filter_cache_items(gen, condition): diff --git a/python/test/dune/perftool/generation/test_cache.py b/python/test/dune/perftool/generation/test_cache.py index 3d97048032d407b41d16699d11bc229bc9b31df1..84775fbfbc504a08bf4adfa477e3d92d116ed055 100644 --- a/python/test/dune/perftool/generation/test_cache.py +++ b/python/test/dune/perftool/generation/test_cache.py @@ -253,3 +253,11 @@ def test_multiple_kernels(): k2, = retrieve_cache_items("k2") assert k2 == "bla" + + @preamble(kernel="k3") + def pre3(): + return "foo" + + pre3() + k3, = retrieve_cache_items("k3") + assert k3 == "foo"