Skip to content
Snippets Groups Projects
Commit 726613da authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Add a cache restoring context manager

parent 8b68d125
No related branches found
No related tags found
No related merge requests found
...@@ -49,6 +49,7 @@ from dune.perftool.generation.loopy import (barrier, ...@@ -49,6 +49,7 @@ from dune.perftool.generation.loopy import (barrier,
valuearg, valuearg,
) )
from dune.perftool.generation.context import (global_context, from dune.perftool.generation.context import (cache_restoring,
global_context,
get_global_context_value, get_global_context_value,
) )
...@@ -268,7 +268,7 @@ def delete_cache_items(condition=True, keep=False): ...@@ -268,7 +268,7 @@ def delete_cache_items(condition=True, keep=False):
gen._memoize_cache = _filter_cache_items(gen, condition) gen._memoize_cache = _filter_cache_items(gen, condition)
def retrieve_cache_functions(condition=True): def retrieve_cache_functions(condition="True"):
return [g.func for g in _generators if eval(condition, _ConditionDict(g.item_tags))] return [g.func for g in _generators if eval(condition, _ConditionDict(g.item_tags))]
......
...@@ -31,3 +31,22 @@ def global_context(**kwargs): ...@@ -31,3 +31,22 @@ def global_context(**kwargs):
def get_global_context_value(key, default=None): def get_global_context_value(key, default=None):
return _global_context_cache.get(key, default) return _global_context_cache.get(key, default)
class _CacheRestoringContext(object):
def __enter__(self):
from dune.perftool.generation.cache import _generators as g
self.cache = {}
for i in g:
self.cache[i] = {}
for k, v in i._memoize_cache.items():
self.cache[i][k] = v
def __exit__(self, exc_type, exc_value, traceback):
for i, c in self.cache.items():
for k, v in c.items():
i._memoize_cache[k] = v
def cache_restoring():
return _CacheRestoringContext()
""" Autotuning for sum factorization kernels """ """ Autotuning for sum factorization kernels """
from dune.perftool.generation import delete_cache_items from dune.perftool.generation import cache_restoring, delete_cache_items
from dune.perftool.loopy.target import DuneTarget from dune.perftool.loopy.target import DuneTarget
from dune.perftool.sumfact.realization import realize_sumfact_kernel_function from dune.perftool.sumfact.realization import realize_sumfact_kernel_function
from dune.perftool.options import get_option from dune.perftool.options import get_option
...@@ -169,7 +169,9 @@ def autotune_realization(sf): ...@@ -169,7 +169,9 @@ def autotune_realization(sf):
logname = "{}.log".format(name) logname = "{}.log".format(name)
# Generate and compile a benchmark program # Generate and compile a benchmark program
generate_standalone_code(sf, filename, logname) with cache_restoring():
generate_standalone_code(sf, filename, logname)
ret = subprocess.call(compiler_invocation(name, filename)) ret = subprocess.call(compiler_invocation(name, filename))
assert ret == 0 assert ret == 0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment