diff --git a/python/dune/perftool/generation/__init__.py b/python/dune/perftool/generation/__init__.py index c7642cec3b4746248907d48f2b9606c0307fa4be..0ab498368e77f906dbca7793fcc973ab40c49f69 100644 --- a/python/dune/perftool/generation/__init__.py +++ b/python/dune/perftool/generation/__init__.py @@ -5,6 +5,7 @@ from dune.perftool.generation.cache import (cached, no_caching, retrieve_cache_items, delete_cache_items, + inspect_generator, ) from dune.perftool.generation.cpp import (base_class, diff --git a/python/dune/perftool/generation/cache.py b/python/dune/perftool/generation/cache.py index c4255edf0873c67956be787c2219446c987deb94..e488c8717fd71f7fca836a68c282dc72ba7fa4d8 100644 --- a/python/dune/perftool/generation/cache.py +++ b/python/dune/perftool/generation/cache.py @@ -8,8 +8,8 @@ from __future__ import absolute_import from pytools import memoize -# have one cache the module level. It is easier than handing around an instance of it. -_cache = {} +# Store a global list of generator functions +_generators = [] def _freeze(data): @@ -106,6 +106,9 @@ class _RegisteredFunction(object): self.func = func self.cache_key_generator = cache_key_generator + # Initialize the memoization cache + self._memoize_cache = {} + # Append the current context item tags to the given ones from dune.perftool.generation.context import get_context_tags if 'item_tags' in kwargs: @@ -115,20 +118,24 @@ class _RegisteredFunction(object): self.itemtype = _construct_cache_item_type("CacheItemType", **kwargs) + # Register this generator function + _generators.append(self) + def __call__(self, *args, **kwargs): # Get the cache key from the given arguments - cache_args = self.cache_key_generator(*args, **kwargs) + cache_key = self.cache_key_generator(*args, **kwargs) # Make sure that all keyword arguments have vanished from the cache_args - assert (lambda *a, **k: len(k) == 0)(cache_args) - cache_key = (self, _freeze(self.cache_key_generator(*args, **kwargs))) + assert (lambda *a, **k: len(k) == 0)(cache_key) + # check whether we have a cache hit - if cache_key in _cache: + if cache_key in self._memoize_cache: # and return the result depending on the cache item type - return _cache[cache_key].content + return self._memoize_cache[cache_key].content else: # evaluate the original function and wrap it in a cache item - citem = self.itemtype(self.func(*args, **kwargs)) - _cache[cache_key] = citem + val = self.func(*args, **kwargs) + citem = self.itemtype(val) + self._memoize_cache[cache_key] = citem return citem.content @@ -211,13 +218,11 @@ class _ConditionDict(dict): return i in self.tags -def _filter_cache_items(condition): - return {k: v for k, v in _cache.items() if eval(condition, _ConditionDict(v.tags))} +def _filter_cache_items(gen, condition): + return {k: v for k, v in gen._memoize_cache.items() if eval(condition, _ConditionDict(v.tags))} def retrieve_cache_items(condition=True, make_generable=False): - choice = _filter_cache_items(condition).values() - def as_generable(content): if make_generable: from cgen import Generable, Line @@ -230,12 +235,18 @@ def retrieve_cache_items(condition=True, make_generable=False): return content # First yield all those items that are not sorted - for item in choice: - if not item.counted: - yield as_generable(item.content) + for gen in _generators: + choice = _filter_cache_items(gen, condition).values() + for item in choice: + if not item.counted: + yield as_generable(item.content) # And now the sorted ones - for item in sorted([i for i in choice if i.counted], key=lambda i: i.content[0]): + counted_ones = [] + for gen in _generators: + counted_ones.extend(filter(lambda i: i.counted, _filter_cache_items(gen, condition).values())) + + for item in sorted(counted_ones, key=lambda i: i.content[0]): from collections import Iterable if isinstance(item.content[1], Iterable) and not isinstance(item.content[1], str): for l in item.content[1]: @@ -249,5 +260,16 @@ def delete_cache_items(condition=True, keep=False): if not keep: condition = "not ({})".format(condition) - global _cache - _cache = _filter_cache_items(condition) + for gen in _generators: + gen._memoize_cache = _filter_cache_items(gen, condition) + + +def inspect_generator(gen): + # Must be a generator function + assert(isinstance(gen, _RegisteredFunction)) + + print("Inspecting generator function {}".format(gen.func.func_name)) + for k, v in gen._memoize_cache.items(): + print(" args: {}".format(k)) + print(" val: {}".format(v.content)) + print(" tags: {}\n".format(v.tags))