Skip to content
Snippets Groups Projects
Commit 60a11b54 authored by René Heß's avatar René Heß
Browse files

Store cache objects in memoization

parent 7c7045aa
No related branches found
No related tags found
No related merge requests found
......@@ -48,6 +48,14 @@ def no_caching(*a, **k):
return get_counter('__no_caching')
class _CacheObject(object):
""" Data type of objects stored in memoization cache of _RegisteredFunction"""
def __init__(self, value, count=None):
self.value = value
self.count = count
self.stackframe = None
class _RegisteredFunction(object):
""" The data structure for a function that accesses UFL2LoopyDataCache """
def __init__(self, func,
......@@ -80,10 +88,8 @@ class _RegisteredFunction(object):
self.func = self.func.func
def _get_content(self, key):
if self.counted:
return self._memoize_cache[key][1]
else:
return self._memoize_cache[key]
return self._memoize_cache[key].value
def __call__(self, *args, **kwargs):
# Modify the kwargs to include any context tags kept with the generator
......@@ -110,11 +116,11 @@ class _RegisteredFunction(object):
except TypeError:
val = self.on_store(self.func(*args, **without_context))
# Maybe wrap it with a counter!
# Store cache object
if self.counted:
val = (get_counter('__cache_counted'), val)
# and store the result
self._memoize_cache[cache_key] = val
self._memoize_cache[cache_key] = _CacheObject(val, count=get_counter('__cache_counted'))
else:
self._memoize_cache[cache_key] = _CacheObject(val)
# Return the result for immediate usage
return self._get_content(cache_key)
......@@ -195,9 +201,9 @@ class _ConditionDict(dict):
#
# eval ("True", _ConditionDict(v.tags)
#
# But in this case we want to return True! A normal dictionary would not attempt
# to replace "True" if "True" is not a key. The _ConditionDict obviously has no
# such concerns ;).
# But in this case we want to return True! A normal dictionary
# would not attempt to replace "True" if "True" is not a
# key. The _ConditionDict has no such concerns ;).
if i == "True":
return True
if i == "False":
......@@ -230,20 +236,20 @@ def retrieve_cache_items(condition=True, make_generable=False):
# First yield all those items that are not sorted
for gen in filter(lambda g: not g.counted, _generators):
for item in _filter_cache_items(gen, condition).values():
yield as_generable(item)
yield as_generable(item.value)
# And now the sorted ones
counted_ones = []
for gen in filter(lambda g: g.counted, _generators):
counted_ones.extend(_filter_cache_items(gen, condition).values())
for item in sorted(counted_ones, key=lambda i: i[0]):
for item in sorted(counted_ones, key=lambda i: i.count):
from collections import Iterable
if isinstance(item[1], Iterable) and not isinstance(item[1], str):
for l in item[1]:
if isinstance(item.value, Iterable) and not isinstance(item.value, str):
for l in item.value:
yield as_generable(l)
else:
yield as_generable(item[1])
yield as_generable(item.value)
def delete_cache_items(condition=True, keep=False):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment