diff --git a/python/dune/perftool/compile.py b/python/dune/perftool/compile.py index 83bfecc46030f1c528f9e51644fd8563dc98c396..e480fd50bd6eab1f0d625a0ed055e45ddda49488 100644 --- a/python/dune/perftool/compile.py +++ b/python/dune/perftool/compile.py @@ -30,7 +30,7 @@ def generate_driver(formdata, filename): from dune.perftool.generation import retrieve_cache_items # Get all preambles for the driver and sort them. - driver_content = [i[1] for i in sorted(retrieve_cache_items(item_name="driver_preamble"), key=lambda x : x[0])] + driver_content = [i[1] for i in sorted(retrieve_cache_items("preamble"), key=lambda x : x[0])] # And flatten out those, that contained nested lists def flatjoin(l): @@ -43,7 +43,7 @@ def generate_driver(formdata, filename): # Write the file. f = open(filename, 'w') - f.write("\n".join(retrieve_cache_items(item_name="driver_include"))) + f.write("\n".join(retrieve_cache_items("include"))) f.write("\n\nvoid driver(int argc, char** argv)\n") f.write("\n".join(driver.generate())) diff --git a/python/dune/perftool/generation.py b/python/dune/perftool/generation.py index 9db423254cfb0a09bee0bc80d298f44e52bfbfd6..0e1ae5ff8b4600a8050cf2f7d416a9da81acf56e 100644 --- a/python/dune/perftool/generation.py +++ b/python/dune/perftool/generation.py @@ -52,7 +52,7 @@ class _CacheItemMeta(type): """ A meta class for cache items. 
Keyword arguments are forwarded th decorator factory below (check the documentation there) """ - def __new__(cls, name, bases, d, counted=False, on_store=lambda x: x): + def __new__(cls, name, bases, d, counted=False, on_store=lambda x: x, item_tags=[]): rettype = type(name, bases, d) if counted: original_on_store = on_store @@ -64,6 +64,7 @@ class _CacheItemMeta(type): def _init(s, x): s.content = on_store(x) + s.tags = item_tags setattr(rettype, '__init__', _init) @@ -85,31 +86,23 @@ class _RegisteredFunction(object): import sys def __init__(self, func, cache_key_generator=lambda *a : a, - item_name="CacheItemType", **kwargs ): self.func = func self.cache_key_generator = cache_key_generator - self.itemtype = _construct_cache_item_type(item_name, **kwargs) - - if item_name == "CacheItemType": - from IPython import embed; embed() - - # If the specified cache does not exist, create it now. - if self.itemtype not in _cache: - _cache[self.itemtype] = {} + self.itemtype = _construct_cache_item_type("CacheItemType", **kwargs) def __call__(self, *args): # Get the cache key from the given arguments cache_key = (self, _freeze(self.cache_key_generator(*args))) # check whether we have a cache hit - if cache_key in _cache[self.itemtype]: + if cache_key in _cache: # and return the result depending on the cache item type - return _cache[self.itemtype][cache_key].content + return _cache[cache_key].content else: # evaluate the original function and wrap it in a cache item citem = self.itemtype(self.func(*args)) - _cache[self.itemtype][cache_key] = citem + _cache[cache_key] = citem return citem.content @@ -132,9 +125,9 @@ def generator_factory(**factory_kwargs): determines whether to use a caches result. The return type is arbitrary, as it will be turned immutable by the cache machine afterwards. Defaults to identity. - item_name : str - A name to give to the cache item type name. Necessary if you want use - isinstance on the cache with good results... 
def _normalize_tags(tags):
    """ Turn a single tag or an iterable of tags into a frozenset of tags. """
    if isinstance(tags, str):
        # A bare string is one tag, not an iterable of one-letter tags
        return frozenset((tags,))
    return frozenset(tags)


def _item_matches(item, tags, union):
    """ Decide whether a cache item is selected by a frozenset of tags.

    If union is True, the item matches as soon as it carries one of the
    given tags. If union is False, the item must carry all of the given
    tags. An empty set of tags never selects anything.
    """
    if not tags:
        return False
    if union:
        return any(t in tags for t in item.tags)
    # Intersection mode: every requested tag has to be present on the item
    # (not the other way around).
    return tags.issubset(item.tags)


def retrieve_cache_items(tags, union=True):
    """ Retrieve items from the cache.

    tags : str or iterable of str
        The tag(s) to select cache items by.
    union : bool
        If union is True, all items that match one of the given tags are
        returned. If union is False, only items that match all tags are
        returned.

    This is a generator yielding the content of each matching item.
    The items do remain in the cache.
    """
    tags = _normalize_tags(tags)
    for item in _cache.values():
        if _item_matches(item, tags, union):
            yield item.content


def delete_cache_items(tags, union=True):
    """ Delete items from the cache.

    If union is True, all items that match one of the given tags are
    deleted. If union is False, only items that match all tags are
    deleted. All other items are kept.
    """
    global _cache
    tags = _normalize_tags(tags)
    # Decide per cache item (not per item content), and rebind the module
    # level cache to the filtered dictionary.
    _cache = dict((key, item) for key, item in _cache.items()
                  if not _item_matches(item, tags, union))


def delete_cache(tags=(), union=True):
    """ Delete the cache, except for the items selected by the given tags.

    If union is True, all items that match one of the given tags are
    kept. If union is False, only items that match all tags are kept.
    Without any tags, the entire cache is emptied.
    """
    global _cache
    tags = _normalize_tags(tags)
    _cache = dict((key, item) for key, item in _cache.items()
                  if _item_matches(item, tags, union))
from dune.perftool.generation import generator_factory -dune_symbol = generator_factory(item_name="dune_symbol") -dune_preamble = generator_factory(item_name="dune_preamble", counted=True) -dune_include = generator_factory(on_store=lambda i: "#include<{}>".format(i), item_name="dune_include", no_deco=True) +dune_symbol = generator_factory(item_tags=("pdelab", "kernel", "symbol")) +dune_preamble = generator_factory(item_tags=("pdelab", "kernel", "preamble"), counted=True) +dune_include = generator_factory(on_store=lambda i: "#include<{}>".format(i), item_tags=("pdelab", "include"), no_deco=True) from dune.perftool.transformer import quadrature_iname from loopy import CInstruction def quadrature_preamble(assignees=[]): - return generator_factory(counted=True, item_name="quadrature_preamble", on_store=lambda code: CInstruction(quadrature_iname(), code, assignees=assignees)) + # TODO: How to enforce the order of quadrature preambles? Counted? + return generator_factory(item_tags=("pdelab", "instruction", "cinstruction", "quadrature"), on_store=lambda code: CInstruction(quadrature_iname(), code, assignees=assignees)) # Now define some commonly used generators that do not fall into a specific category diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py index 45d12cc60dcb27f688ea2e490bab42523d022300..99f7ea53d31504119b1d346cd291c2a1cc8ebe26 100644 --- a/python/dune/perftool/pdelab/localoperator.py +++ b/python/dune/perftool/pdelab/localoperator.py @@ -3,9 +3,9 @@ from dune.perftool.options import get_option from dune.perftool.generation import generator_factory # Define the generators used in-here -operator_include = generator_factory(item_name="operator_include", on_store=lambda i: "#include<{}>".format(i), no_deco=True) -base_class = generator_factory(item_name="operator_base_classes", counted=True, no_deco=True) -initializer_list = generator_factory(item_name="operator_initializerlist", counted=True) 
+operator_include = generator_factory(item_tags=("pdelab", "include", "operator"), on_store=lambda i: "#include<{}>".format(i), no_deco=True) +base_class = generator_factory(item_tags=("pdelab", "baseclass", "operator"), counted=True, no_deco=True) +initializer_list = generator_factory(item_tags=("pdelab", "initializer", "operator"), counted=True) @memoize def measure_specific_details(measure): @@ -47,7 +47,7 @@ def generate_term(integrand=None, measure=None): # Delete all non-include parts of the cache. # TODO: add things such as base classes as cache items. from dune.perftool.generation import delete_cache - delete_cache(exclude="dune_include") + delete_cache() # Get the measure specifics specifics = measure_specific_details(measure) @@ -60,12 +60,11 @@ def generate_term(integrand=None, measure=None): # First extracting it, might be useful to alter it before kernel generation. from dune.perftool.generation import retrieve_cache_items from dune.perftool.target import DuneTarget - domains = [i for i in retrieve_cache_items(item_name="loopdomain")] - instructions = [i for i in retrieve_cache_items(item_name="c_instruction")] \ - + [i for i in retrieve_cache_items(item_name="expr_instruction")] \ - + [i[1] for i in retrieve_cache_items(item_name="quadrature_preamble")] - temporaries = {i.name:i for i in retrieve_cache_items(item_name="temporary")} - preambles = [i[1] for i in retrieve_cache_items(item_name="dune_preamble")] + domains = [i for i in retrieve_cache_items("domain")] + instructions = [i for i in retrieve_cache_items("cinstruction")] \ + + [i for i in retrieve_cache_items("exprinstruction")] + temporaries = {i.name:i for i in retrieve_cache_items("temporary")} + preambles = [i[1] for i in retrieve_cache_items("preamble")] print "Printing the information that we found:\n\nDomains:" for d in domains: @@ -142,6 +141,4 @@ def generate_localoperator(ufldata, operatorfile): base_class('Dune::PDELab::LocalOperatorDefaultFlags') - from IPython import embed; 
embed() - print operator_methods[0] diff --git a/python/dune/perftool/transformer.py b/python/dune/perftool/transformer.py index 63db1e350bec6cf3656785b87587022ddca40cbd..14269bf0a4a913447898891ba0a19164776722fa 100644 --- a/python/dune/perftool/transformer.py +++ b/python/dune/perftool/transformer.py @@ -11,12 +11,12 @@ from dune.perftool.restriction import Restriction # Define the generators that are used here from dune.perftool.generation import generator_factory -loopy_iname = generator_factory(item_name="inames") -loopy_expr_instruction = generator_factory(item_name="expr_instruction", no_deco=True) -loopy_temporary_variable = generator_factory(item_name="temporary", on_store=lambda n: loopy.TemporaryVariable(n, dtype=numpy.float64), no_deco=True) -loopy_c_instruction = generator_factory(item_name="c_instruction", no_deco=True) +loopy_iname = generator_factory(item_tags=("loopy", "kernel", "iname")) +loopy_expr_instruction = generator_factory(item_tags=("loopy", "kernel", "instruction", "exprinstruction"), no_deco=True) +loopy_temporary_variable = generator_factory(item_tags=("loopy", "kernel", "temporary"), on_store=lambda n: loopy.TemporaryVariable(n, dtype=numpy.float64), no_deco=True) +loopy_c_instruction = generator_factory(item_tags=("loopy", "kernel", "instruction", "cinstruction"), no_deco=True) -@generator_factory(item_name="loopdomain") +@generator_factory(item_tags=("loopy", "kernel", "domain")) def loopy_domain(iname, shape): return "{{ [{0}] : 0<={0}<{1} }}".format(iname, shape)