Skip to content
Snippets Groups Projects
Commit aec6e23b authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Allow syntax like '@preamble(kernel=...)'

parent e559ebec
No related branches found
No related tags found
No related merge requests found
""" This module provides the memoization infrastructure for code """ This module provides the memoization infrastructure for code
generating functions. generating functions.
""" """
from dune.perftool.generation.context import get_global_context_value from dune.perftool.generation.context import (get_global_context_value,
global_context,
)
from dune.perftool.generation.counter import get_counter from dune.perftool.generation.counter import get_counter
# Store a global list of generator functions # Store a global list of generator functions
...@@ -54,6 +56,7 @@ class _RegisteredFunction(object): ...@@ -54,6 +56,7 @@ class _RegisteredFunction(object):
on_store=lambda x: x, on_store=lambda x: x,
item_tags=(), item_tags=(),
context_tags=(), context_tags=(),
**kwargs
): ):
self.func = func self.func = func
self.cache_key_generator = cache_key_generator self.cache_key_generator = cache_key_generator
...@@ -61,6 +64,7 @@ class _RegisteredFunction(object): ...@@ -61,6 +64,7 @@ class _RegisteredFunction(object):
self.on_store = on_store self.on_store = on_store
self.item_tags = item_tags self.item_tags = item_tags
self.context_tags = context_tags self.context_tags = context_tags
self.kwargs = kwargs
# Initialize the memoization cache # Initialize the memoization cache
self._memoize_cache = {} self._memoize_cache = {}
...@@ -81,7 +85,7 @@ class _RegisteredFunction(object): ...@@ -81,7 +85,7 @@ class _RegisteredFunction(object):
else: else:
return self._memoize_cache[key] return self._memoize_cache[key]
def __call__(self, *args, **kwargs): def call(self, *args, **kwargs):
# Get the cache key from the given arguments # Get the cache key from the given arguments
cache_key = self.cache_key_generator(*args, **kwargs) cache_key = self.cache_key_generator(*args, **kwargs)
...@@ -105,6 +109,10 @@ class _RegisteredFunction(object): ...@@ -105,6 +109,10 @@ class _RegisteredFunction(object):
# Return the result for immediate usage # Return the result for immediate usage
return self._get_content(cache_key) return self._get_content(cache_key)
def __call__(self, *args, **kwargs):
with global_context(**self.kwargs):
return self.call(*args, **kwargs)
def generator_factory(**factory_kwargs): def generator_factory(**factory_kwargs):
""" A function decorator factory """ A function decorator factory
...@@ -172,24 +180,24 @@ cached = generator_factory(item_tags=("default_cached",)) ...@@ -172,24 +180,24 @@ cached = generator_factory(item_tags=("default_cached",))
class _ConditionDict(dict): class _ConditionDict(dict):
def __init__(self, tags): def __init__(self, tags):
dict.__init__(self) dict.__init__(self)
self.tags = tags self.tags = tags
def __getitem__(self, i): def __getitem__(self, i):
# If we do not add these special cases the dictionary will return False # If we do not add these special cases the dictionary will return False
# when we execute the following code: # when we execute the following code:
# #
# eval ("True", _ConditionDict(v.tags) # eval ("True", _ConditionDict(v.tags)
# #
# But in this case we want to return True! A normal dictionary would not attempt # But in this case we want to return True! A normal dictionary would not attempt
# to replace "True" if "True" is not a key. The _ConditionDict obviously has no # to replace "True" if "True" is not a key. The _ConditionDict obviously has no
# such concerns ;). # such concerns ;).
if i == "True": if i == "True":
return True return True
if i == "False": if i == "False":
return False return False
return i in self.tags return i in self.tags
def _filter_cache_items(gen, condition): def _filter_cache_items(gen, condition):
......
...@@ -253,3 +253,11 @@ def test_multiple_kernels(): ...@@ -253,3 +253,11 @@ def test_multiple_kernels():
k2, = retrieve_cache_items("k2") k2, = retrieve_cache_items("k2")
assert k2 == "bla" assert k2 == "bla"
@preamble(kernel="k3")
def pre3():
return "foo"
pre3()
k3, = retrieve_cache_items("k3")
assert k3 == "foo"
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment