Skip to content
Snippets Groups Projects
Commit bfcf525e authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Remove cache_context and CacheItemTypes

They are concepts that are not really necessary with local
storage, and that I do not really want to explain to other people.
parent c7cb7bd0
No related branches found
No related tags found
No related merge requests found
...@@ -59,13 +59,12 @@ def compile_form(): ...@@ -59,13 +59,12 @@ def compile_form():
from dune.perftool.options import get_option from dune.perftool.options import get_option
formdatas, data = read_ufl(get_option("uflfile")) formdatas, data = read_ufl(get_option("uflfile"))
from dune.perftool.generation import cache_context, global_context from dune.perftool.generation import global_context
with global_context(data=data, formdatas=formdatas): with global_context(data=data, formdatas=formdatas):
# Generate driver file # Generate driver file
if get_option("driver_file"): if get_option("driver_file"):
with cache_context('driver', delete=True): from dune.perftool.pdelab.driver import generate_driver
from dune.perftool.pdelab.driver import generate_driver generate_driver(formdatas, data)
generate_driver(formdatas, data)
# In case of multiple forms: Generate one file that includes all localoperator files # In case of multiple forms: Generate one file that includes all localoperator files
if len(formdatas) > 1: if len(formdatas) > 1:
......
...@@ -30,7 +30,6 @@ from dune.perftool.generation.loopy import (constantarg, ...@@ -30,7 +30,6 @@ from dune.perftool.generation.loopy import (constantarg,
valuearg, valuearg,
) )
from dune.perftool.generation.context import (cache_context, from dune.perftool.generation.context import (global_context,
global_context,
get_global_context_value, get_global_context_value,
) )
""" This module provides the infrastructure for loopy preambles that have """ This module provides the memoization infrastructure for code
a complex requirement structure. This includes: generating functions.
* preambles triggering the creation of other preambles
* a caching mechanism to avoid duplicated preambles where harmful
""" """
from __future__ import absolute_import
from pytools import memoize
# Store a global list of generator functions # Store a global list of generator functions
_generators = [] _generators = []
...@@ -58,69 +52,32 @@ def no_caching(*a, **k): ...@@ -58,69 +52,32 @@ def no_caching(*a, **k):
return _GlobalCounter().get() return _GlobalCounter().get()
class _CacheItemMeta(type):
""" A meta class for cache items. Keyword arguments are forwarded
to the decorator factory below (check the documentation there)
"""
def __new__(cls, name, bases, d, counted=False, on_store=lambda x: x, item_tags=[]):
rettype = type(name, bases, d)
if counted:
original_on_store = on_store
def add_count(x):
count = _GlobalCounter().get()
if isinstance(x, tuple):
return (count, original_on_store(*x))
else:
return (count, original_on_store(x))
on_store = add_count
def _init(s, x):
if isinstance(x, tuple) and not counted:
s.content = on_store(*x)
else:
s.content = on_store(x)
s.tags = item_tags
s.counted = counted
setattr(rettype, '__init__', _init)
return rettype
@memoize(use_kwargs=True)
def _construct_cache_item_type(name, **kwargs):
""" Wrap the generation of cache item types from the meta class.
At the same time, memoization assures that types are the same for
multiple cache items
"""
return _CacheItemMeta.__new__(_CacheItemMeta, name, (), {}, **kwargs)
class _RegisteredFunction(object): class _RegisteredFunction(object):
""" The data structure for a function that accesses UFL2LoopyDataCache """ """ The data structure for a function that accesses UFL2LoopyDataCache """
def __init__(self, func, def __init__(self, func,
cache_key_generator=lambda *a, **kw: a, cache_key_generator=lambda *a, **kw: a,
**kwargs counted=False,
on_store=lambda x: x,
item_tags=()
): ):
self.func = func self.func = func
self.cache_key_generator = cache_key_generator self.cache_key_generator = cache_key_generator
self.counted = counted
self.on_store = on_store
self.item_tags = item_tags
# Initialize the memoization cache # Initialize the memoization cache
self._memoize_cache = {} self._memoize_cache = {}
# Append the current context item tags to the given ones
from dune.perftool.generation.context import get_context_tags
if 'item_tags' in kwargs:
kwargs['item_tags'] = tuple(kwargs['item_tags']) + get_context_tags()
else:
kwargs['item_tags'] = get_context_tags()
self.itemtype = _construct_cache_item_type("CacheItemType", **kwargs)
# Register this generator function # Register this generator function
_generators.append(self) _generators.append(self)
def _get_content(self, key):
if self.counted:
return self._memoize_cache[key][1]
else:
return self._memoize_cache[key]
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
# Get the cache key from the given arguments # Get the cache key from the given arguments
cache_key = self.cache_key_generator(*args, **kwargs) cache_key = self.cache_key_generator(*args, **kwargs)
...@@ -128,15 +85,17 @@ class _RegisteredFunction(object): ...@@ -128,15 +85,17 @@ class _RegisteredFunction(object):
assert (lambda *a, **k: len(k) == 0)(cache_key) assert (lambda *a, **k: len(k) == 0)(cache_key)
# check whether we have a cache hit # check whether we have a cache hit
if cache_key in self._memoize_cache: if cache_key not in self._memoize_cache:
# and return the result depending on the cache item type # evaluate the original function
return self._memoize_cache[cache_key].content val = self.on_store(self.func(*args, **kwargs))
else: # Maybe wrap it with a counter!
# evaluate the original function and wrap it in a cache item if self.counted:
val = self.func(*args, **kwargs) val = (_GlobalCounter().get(), val)
citem = self.itemtype(val) # and store the result
self._memoize_cache[cache_key] = citem self._memoize_cache[cache_key] = val
return citem.content
# Return the result for immediate usage
return self._get_content(cache_key)
def generator_factory(**factory_kwargs): def generator_factory(**factory_kwargs):
...@@ -219,7 +178,7 @@ class _ConditionDict(dict): ...@@ -219,7 +178,7 @@ class _ConditionDict(dict):
def _filter_cache_items(gen, condition): def _filter_cache_items(gen, condition):
return {k: v for k, v in gen._memoize_cache.items() if eval(condition, _ConditionDict(v.tags))} return {k: v for k, v in gen._memoize_cache.items() if eval(condition, _ConditionDict(gen.item_tags))}
def retrieve_cache_items(condition=True, make_generable=False): def retrieve_cache_items(condition=True, make_generable=False):
...@@ -235,24 +194,22 @@ def retrieve_cache_items(condition=True, make_generable=False): ...@@ -235,24 +194,22 @@ def retrieve_cache_items(condition=True, make_generable=False):
return content return content
# First yield all those items that are not sorted # First yield all those items that are not sorted
for gen in _generators: for gen in filter(lambda g: not g.counted, _generators):
choice = _filter_cache_items(gen, condition).values() for item in _filter_cache_items(gen, condition).values():
for item in choice: yield as_generable(item)
if not item.counted:
yield as_generable(item.content)
# And now the sorted ones # And now the sorted ones
counted_ones = [] counted_ones = []
for gen in _generators: for gen in filter(lambda g: g.counted, _generators):
counted_ones.extend(filter(lambda i: i.counted, _filter_cache_items(gen, condition).values())) counted_ones.extend(_filter_cache_items(gen, condition).values())
for item in sorted(counted_ones, key=lambda i: i.content[0]): for item in sorted(counted_ones, key=lambda i: i[0]):
from collections import Iterable from collections import Iterable
if isinstance(item.content[1], Iterable) and not isinstance(item.content[1], str): if isinstance(item[1], Iterable) and not isinstance(item[1], str):
for l in item.content[1]: for l in item[1]:
yield as_generable(l) yield as_generable(l)
else: else:
yield as_generable(item.content[1]) yield as_generable(item[1])
def delete_cache_items(condition=True, keep=False): def delete_cache_items(condition=True, keep=False):
...@@ -271,5 +228,4 @@ def inspect_generator(gen): ...@@ -271,5 +228,4 @@ def inspect_generator(gen):
print("Inspecting generator function {}".format(gen.func.func_name)) print("Inspecting generator function {}".format(gen.func.func_name))
for k, v in gen._memoize_cache.items(): for k, v in gen._memoize_cache.items():
print(" args: {}".format(k)) print(" args: {}".format(k))
print(" val: {}".format(v.content)) print(" val: {}".format(v))
print(" tags: {}\n".format(v.tags))
""" Context managers for code generation. """ """ Context managers for code generation. """
# Implement a context manager that allows to temporarily define tags globally.
_cache_context_stack = []
class _CacheContext(object):
def __init__(self, tags, delete=False):
self.tags = tags
self.delete = delete
def __enter__(self):
_cache_context_stack.append(self.tags)
def __exit__(self, exc_type, exc_value, traceback):
_cache_context_stack.pop()
if self.delete:
from dune.perftool.generation.cache import delete_cache_items
delete_cache_items(condition=" and ".join(self.tags))
def cache_context(*tags, **kwargs):
return _CacheContext(tags, **kwargs)
def get_context_tags():
result = tuple()
for items in _cache_context_stack:
result = result + items
return result
_global_context_cache = {} _global_context_cache = {}
......
...@@ -172,7 +172,7 @@ def traverse_lfs_tree(arg): ...@@ -172,7 +172,7 @@ def traverse_lfs_tree(arg):
from dune.perftool.pdelab.argument import name_argumentspace from dune.perftool.pdelab.argument import name_argumentspace
lfs_basename = name_argumentspace(arg) lfs_basename = name_argumentspace(arg)
from dune.perftool.pdelab.localoperator import lop_template_gfs from dune.perftool.pdelab.localoperator import lop_template_gfs
gfs_basename = lop_template_gfs(arg)[1] gfs_basename = lop_template_gfs(arg)
# Now start recursively extracting local function spaces and fill the cache with # Now start recursively extracting local function spaces and fill the cache with
# all those values. That way we can later get a correct local function space with # all those values. That way we can later get a correct local function space with
......
...@@ -1417,7 +1417,7 @@ def generate_driver(formdatas, data): ...@@ -1417,7 +1417,7 @@ def generate_driver(formdatas, data):
from dune.perftool.generation import retrieve_cache_items from dune.perftool.generation import retrieve_cache_items
from cgen import FunctionDeclaration, FunctionBody, Block, Value from cgen import FunctionDeclaration, FunctionBody, Block, Value
driver_signature = FunctionDeclaration(Value('void', 'driver'), [Value('int', 'argc'), Value('char**', 'argv')]) driver_signature = FunctionDeclaration(Value('void', 'driver'), [Value('int', 'argc'), Value('char**', 'argv')])
driver_body = Block(contents=[i for i in retrieve_cache_items("driver and preamble", make_generable=True)]) driver_body = Block(contents=[i for i in retrieve_cache_items("preamble", make_generable=True)])
driver = FunctionBody(driver_signature, driver_body) driver = FunctionBody(driver_signature, driver_body)
filename = get_option("driver_file") filename = get_option("driver_file")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment