diff --git a/python/dune/perftool/cgen/clazz.py b/python/dune/perftool/cgen/clazz.py index c673b7c495007591e34dd105beb49c7027b72a55..f265638ec68ba122cd209909b7c2704cd6992aaf 100644 --- a/python/dune/perftool/cgen/clazz.py +++ b/python/dune/perftool/cgen/clazz.py @@ -61,57 +61,18 @@ class ClassMember(Generable): yield line + '\n' -class Constructor(Generable): - def __init__(self, block=Block([]), arg_decls=[], clsname=None, initializer_list=[], access=AccessModifier.PUBLIC): - self.clsname = clsname - self.arg_decls = arg_decls - self.access = access - self.block = block - self.il = initializer_list - - def generate(self): - assert self.clsname - - yield '\n' - yield "{}:\n".format(access_modifier_string(self.access)) - yield self.clsname + "(" - if self.arg_decls: - for content in self.arg_decls[0].generate(with_semicolon=False): - yield content - for ad in self.arg_decls[1:]: - yield ", " - for content in ad.generate(with_semicolon=False): - yield content - yield ")\n" - - # add the initializer list - if self.il: - yield " : {}".format(self.il[0]) - - for i in self.il[1:]: - yield ",\n" - yield " {}".format(i) - yield '\n' - - for line in self.block.generate(): - yield line - - class Class(Generable): """ Generator for a templated class """ - def __init__(self, name, base_classes=[], members=[], tparam_decls=[], constructors=[]): + def __init__(self, name, base_classes=[], members=[], tparam_decls=[]): self.name = name self.base_classes = base_classes self.members = members self.tparam_decls = tparam_decls - self.constructors = constructors for bc in base_classes: assert isinstance(bc, BaseClass) for mem in members: assert isinstance(mem, ClassMember) - for con in constructors: - assert isinstance(con, Constructor) def generate(self): # define the class header @@ -139,7 +100,7 @@ class Class(Generable): yield '\n' # Now yield the entire block - block = Block(contents=self.constructors + self.members) + block = Block(contents=self.members) # Yield the block for line in block.generate(): diff --git a/python/dune/perftool/generation/__init__.py b/python/dune/perftool/generation/__init__.py index 17f9113b09abf154f8896f5a8c081804772629fa..72fad9024cc46d3084b8318a2ff7db4c0c91cc52 100644 --- a/python/dune/perftool/generation/__init__.py +++ b/python/dune/perftool/generation/__init__.py @@ -18,7 +18,6 @@ from dune.perftool.generation.cache import (cached, from dune.perftool.generation.cpp import (base_class, class_basename, class_member, - constructor_block, constructor_parameter, dump_accumulate_timer, include_file, @@ -37,6 +36,7 @@ from dune.perftool.generation.loopy import (barrier, globalarg, iname, instruction, + kernel_cached, noop_instruction, silenced_warning, temporary_variable, diff --git a/python/dune/perftool/generation/cache.py b/python/dune/perftool/generation/cache.py index 53c8eada2d1722bc554bd802527264facbad8ed9..9f7aba9be5c24e06a7f16e98624ba04b08eefa3e 100644 --- a/python/dune/perftool/generation/cache.py +++ b/python/dune/perftool/generation/cache.py @@ -1,6 +1,9 @@ """ This module provides the memoization infrastructure for code generating functions. """ +from dune.perftool.generation.context import (get_global_context_value, + global_context, + ) from dune.perftool.generation.counter import get_counter # Store a global list of generator functions @@ -51,13 +54,17 @@ class _RegisteredFunction(object): cache_key_generator=lambda *a, **kw: a, counted=False, on_store=lambda x: x, - item_tags=() + item_tags=(), + context_tags=(), + **kwargs ): self.func = func self.cache_key_generator = cache_key_generator self.counted = counted self.on_store = on_store self.item_tags = item_tags + self.context_tags = context_tags + self.kwargs = kwargs # Initialize the memoization cache self._memoize_cache = {} @@ -78,12 +85,17 @@ class _RegisteredFunction(object): else: return self._memoize_cache[key] - def __call__(self, *args, **kwargs): + def call(self, *args, **kwargs): # Get the cache key from the given arguments cache_key = self.cache_key_generator(*args, **kwargs) + # Make sure that all keyword arguments have vanished from the cache_args assert (lambda *a, **k: len(k) == 0)(cache_key) + # Add any context tags to the cache key + context_key = tuple(get_global_context_value(t, None) for t in self.context_tags) + cache_key = (cache_key, context_key) + # check whether we have a cache hit if cache_key not in self._memoize_cache: # evaluate the original function @@ -97,6 +109,14 @@ class _RegisteredFunction(object): # Return the result for immediate usage return self._get_content(cache_key) + def __call__(self, *args, **kwargs): + additional_kw = {k: kwargs[k] for k in kwargs if k in self.context_tags} + for k, v in self.kwargs.items(): + additional_kw[k] = v + kwargs = {k: kwargs[k] for k in kwargs if k not in self.context_tags} + with global_context(**additional_kw): + return self.call(*args, **kwargs) + def generator_factory(**factory_kwargs): """ A function decorator factory @@ -128,10 +148,17 @@ def generator_factory(**factory_kwargs): items automatically turns to a tuple with the counttag being the first entry. no_deco : bool Instead of a decorator, return a function that uses identity as a body. + context_tags: tuple, str + A single tag or tuple thereof, that will be added to the cache key. This + feature can be used to maintain multiple sets of memoized function evaluations, + for example if you generate multiple loopy kernels at the same time. The + given strings are used to look up in the global context manager for a tag. """ # Tuplize the item_tags parameter if "item_tags" in factory_kwargs and isinstance(factory_kwargs["item_tags"], str): factory_kwargs["item_tags"] = (factory_kwargs["item_tags"],) + if "context_tags" in factory_kwargs and isinstance(factory_kwargs["context_tags"], str): + factory_kwargs["context_tags"] = (factory_kwargs["context_tags"],) no_deco = factory_kwargs.pop("no_deco", False) @@ -157,28 +184,34 @@ cached = generator_factory(item_tags=("default_cached",)) class _ConditionDict(dict): - def __init__(self, tags): - dict.__init__(self) - self.tags = tags - - def __getitem__(self, i): - # If we do not add these special cases the dictionary will return False - # when we execute the following code: - # - # eval ("True", _ConditionDict(v.tags) - # - # But in this case we want to return True! A normal dictionary would not attempt - # to replace "True" if "True" is not a key. The _ConditionDict obviously has no - # such concerns ;). - if i == "True": - return True - if i == "False": - return False - return i in self.tags + def __init__(self, tags): + dict.__init__(self) + self.tags = tags + + def __getitem__(self, i): + # If we do not add these special cases the dictionary will return False + # when we execute the following code: + # + # eval ("True", _ConditionDict(v.tags) + # + # But in this case we want to return True! A normal dictionary would not attempt + # to replace "True" if "True" is not a key. The _ConditionDict obviously has no + # such concerns ;). + if i == "True": + return True + if i == "False": + return False + return i in self.tags def _filter_cache_items(gen, condition): - return {k: v for k, v in gen._memoize_cache.items() if eval(condition, _ConditionDict(gen.item_tags))} + ret = {} + for k, v in gen._memoize_cache.items(): + _, context_tags = k + if eval(condition, _ConditionDict(gen.item_tags + context_tags)): + ret[k] = v + + return ret def retrieve_cache_items(condition=True, make_generable=False): diff --git a/python/dune/perftool/generation/context.py b/python/dune/perftool/generation/context.py index 3766971dd6d469778f71ee74cdc7cd210ed99722..55feda422efbabe48c7663e87fea289585253cde 100644 --- a/python/dune/perftool/generation/context.py +++ b/python/dune/perftool/generation/context.py @@ -12,7 +12,7 @@ class _GlobalContext(object): for k, v in self.kw.items(): # First store existing values of the same keys if k in _global_context_cache: - self.old_kw[k] = v + self.old_kw[k] = _global_context_cache[k] # Now replace the value with the new one _global_context_cache[k] = v diff --git a/python/dune/perftool/generation/cpp.py b/python/dune/perftool/generation/cpp.py index ea8727be7c700898724cc0e7ad903c0a1a1fdd63..507c8217d335936ec3c70ceb845984a734ee2738 100644 --- a/python/dune/perftool/generation/cpp.py +++ b/python/dune/perftool/generation/cpp.py @@ -4,86 +4,44 @@ are commonly needed for code generation """ from dune.perftool.generation import generator_factory -from dune.perftool.cgen.clazz import AccessModifier +from dune.perftool.cgen.clazz import AccessModifier, BaseClass, ClassMember -preamble = generator_factory(item_tags=("preamble",), counted=True) +import cgen -def pre_include(pre, filetag=None, pre_include=True): - assert filetag - gen = generator_factory(item_tags=("file", filetag, "pre_include"), no_deco=True) - return gen(pre) +preamble = generator_factory(item_tags=("preamble",), counted=True, context_tags="kernel") +pre_include = generator_factory(item_tags=("file", "pre_include"), context_tags=("filetag",), no_deco=True) +post_include = generator_factory(item_tags=("file", "post_include"), context_tags=("filetag",), no_deco=True) +class_member = generator_factory(item_tags=("member",), context_tags=("classtag",), on_store=lambda m: ClassMember(m), counted=True) +template_parameter = generator_factory(item_tags=("template_param",), context_tags=("classtag",), counted=True) +class_basename = generator_factory(item_tags=("basename",), context_tags=("classtag",)) -def post_include(post, filetag=None, pre_include=True): - assert filetag - gen = generator_factory(item_tags=("file", filetag, "post_include"), no_deco=True) - return gen(post) - - -def include_file(include, filetag=None, system=False): - assert filetag - from cgen import Include - gen = generator_factory(on_store=lambda i: Include(i, system=system), item_tags=("file", filetag, "include"), no_deco=True) - return gen(include) +@generator_factory(item_tags=("file", "include"), context_tags=("filetag",)) +def include_file(include, system=False): + return cgen.Include(include, system=system) +@generator_factory(item_tags=("clazz", "initializer"), counted=True, context_tags=("classtag",), cache_key_generator=lambda o, p: o) def initializer_list(obj, params, classtag=None): - assert classtag - gen = generator_factory(item_tags=("clazz", classtag, "initializer"), counted=True, no_deco=True, cache_key_generator=lambda *a: a[0]) - return gen("{}({})".format(obj, ", ".join(params))) - + return "{}({})".format(obj, ", ".join(params)) -def base_class(baseclass, classtag=None, access=AccessModifier.PUBLIC, construction=[]): - assert classtag - - from dune.perftool.cgen.clazz import BaseClass - gen = generator_factory(item_tags=("clazz", "baseclass", classtag), on_store=lambda n: BaseClass(n, inheritance=access), counted=True, no_deco=True) +@generator_factory(item_tags=("clazz", "baseclass"), context_tags=("classtag",), counted=True) +def base_class(baseclass, access=AccessModifier.PUBLIC, construction=[], **kwargs): if construction: - initializer_list(baseclass, construction, classtag=classtag) - - return gen(baseclass) - - -def class_member(classtag=None, access=AccessModifier.PRIVATE): - assert classtag - from cgen import Value - from dune.perftool.cgen.clazz import ClassMember - - return generator_factory(item_tags=("clazz", classtag, "member"), on_store=lambda m: ClassMember(m, access=access), counted=True) - + initializer_list(baseclass, construction, **kwargs) -def constructor_parameter(_type, name, classtag=None, constructortag="default"): - assert classtag - assert constructortag - from cgen import Value + return BaseClass(baseclass, inheritance=access) - gen = generator_factory(item_tags=("clazz", classtag, constructortag, "constructor_param"), counted=True, no_deco=True) - return gen(Value(_type, name)) - -def template_parameter(classtag=None): - assert classtag - - return generator_factory(item_tags=("clazz", classtag, "template_param"), counted=True) - - -def class_basename(classtag=None): - assert classtag - - return generator_factory(item_tags=("clazz", classtag, "basename")) - - -def constructor_block(classtag=None): - assert classtag - from dune.perftool.generation import generator_factory - return generator_factory(item_tags=("clazz", classtag, "constructor_block"), counted=True) +@generator_factory(item_tags=("clazz", "constructor_param"), context_tags=("classtag",), counted=True) +def constructor_parameter(_type, name): + return cgen.Value(_type, name) +@generator_factory(item_tags=("dump_timers",)) def dump_accumulate_timer(name): - gen = generator_factory(item_tags=("dump_timers"), no_deco=True) - from dune.perftool.pdelab.localoperator import (name_time_dumper_os, name_time_dumper_reset, name_time_dumper_t, @@ -95,4 +53,4 @@ def dump_accumulate_timer(name): counter = name_time_dumper_counter() code = "DUMP_AND_ACCUMULATE_TIMER({},{},{},{},{});".format(name, os, reset, t, counter) - return gen(code) + return code diff --git a/python/dune/perftool/generation/loopy.py b/python/dune/perftool/generation/loopy.py index 8e7e835137e0c9ab987db1e7e1c3b7bab1abdf6a..c61fec6a21dbe47e9a4147008ae616471af1a5bc 100644 --- a/python/dune/perftool/generation/loopy.py +++ b/python/dune/perftool/generation/loopy.py @@ -6,46 +6,49 @@ from dune.perftool.generation import (get_counter, no_caching, preamble, ) -from dune.perftool.generation.context import get_global_context_value from dune.perftool.error import PerftoolLoopyError -import loopy -import numpy +import loopy as lp +import numpy as np -iname = generator_factory(item_tags=("iname",)) -function_mangler = generator_factory(item_tags=("mangler",)) -silenced_warning = generator_factory(item_tags=("silenced_warning",), no_deco=True) +iname = generator_factory(item_tags=("iname",), context_tags="kernel") +function_mangler = generator_factory(item_tags=("mangler",), context_tags="kernel") +silenced_warning = generator_factory(item_tags=("silenced_warning",), no_deco=True, context_tags="kernel") +kernel_cached = generator_factory(item_tags=("default_cached",), context_tags="kernel") -class DuneGlobalArg(loopy.GlobalArg): - allowed_extra_kwargs = loopy.GlobalArg.allowed_extra_kwargs + ["managed"] +class DuneGlobalArg(lp.GlobalArg): + allowed_extra_kwargs = lp.GlobalArg.allowed_extra_kwargs + ["managed"] @generator_factory(item_tags=("argument", "globalarg"), + context_tags="kernel", cache_key_generator=lambda n, **kw: n) -def globalarg(name, shape=loopy.auto, managed=True, **kw): +def globalarg(name, shape=lp.auto, managed=True, **kw): if isinstance(shape, str): shape = (shape,) - dtype = kw.pop("dtype", numpy.float64) + dtype = kw.pop("dtype", np.float64) return DuneGlobalArg(name, dtype=dtype, shape=shape, managed=managed, **kw) @generator_factory(item_tags=("argument", "constantarg"), + context_tags="kernel", cache_key_generator=lambda n, **kw: n) -def constantarg(name, shape=loopy.auto, **kw): +def constantarg(name, shape=None, **kw): if isinstance(shape, str): shape = (shape,) - dtype = kw.pop("dtype", numpy.float64) - return loopy.GlobalArg(name, dtype=dtype, shape=shape, **kw) + dtype = kw.pop("dtype", np.float64) + return lp.GlobalArg(name, dtype=dtype, shape=shape, **kw) @generator_factory(item_tags=("argument", "valuearg"), + context_tags="kernel", cache_key_generator=lambda n, **kw: n) def valuearg(name, **kw): - return loopy.ValueArg(name, **kw) + return lp.ValueArg(name, **kw) -@generator_factory(item_tags=("domain",)) +@generator_factory(item_tags=("domain",), context_tags="kernel") def domain(iname, shape): if isinstance(shape, str): valuearg(shape) @@ -56,10 +59,12 @@ def get_temporary_name(): return 'expr_{}'.format(str(get_counter('__temporary').zfill(4))) -@generator_factory(item_tags=("temporary",), cache_key_generator=lambda n, **kw: n) +@generator_factory(item_tags=("temporary",), + context_tags="kernel", + cache_key_generator=lambda n, **kw: n) def temporary_variable(name, **kwargs): from dune.perftool.loopy.temporary import DuneTemporaryVariable - return DuneTemporaryVariable(name, scope=loopy.temp_var_scope.PRIVATE, **kwargs) + return DuneTemporaryVariable(name, scope=lp.temp_var_scope.PRIVATE, **kwargs) # Now define generators for instructions. To ease dependency handling of instructions @@ -69,6 +74,7 @@ def temporary_variable(name, **kwargs): @generator_factory(item_tags=("instruction", "cinstruction"), + context_tags="kernel", cache_key_generator=lambda *a, **kw: kw['code'], ) def c_instruction_impl(**kw): @@ -77,24 +83,26 @@ def c_instruction_impl(**kw): kw['assignees'] = frozenset(Variable(i) for i in kw['assignees']) inames = kw.pop('inames', kw.get('forced_iname_deps', [])) - return loopy.CInstruction(inames, **kw) + return lp.CInstruction(inames, **kw) @generator_factory(item_tags=("instruction", "exprinstruction"), + context_tags="kernel", cache_key_generator=lambda *a, **kw: kw['expression'], ) def expr_instruction_impl(**kw): if 'assignees' in kw: from pymbolic.primitives import Variable kw['assignees'] = frozenset(Variable(i) for i in kw['assignees']) - return loopy.ExpressionInstruction(**kw) + return lp.ExpressionInstruction(**kw) @generator_factory(item_tags=("instruction", "callinstruction"), + context_tags="kernel", cache_key_generator=lambda *a, **kw: kw['expression'], ) def call_instruction_impl(**kw): - return loopy.CallInstruction(**kw) + return lp.CallInstruction(**kw) def _insn_cache_key(code=None, expression=None, **kwargs): @@ -105,7 +113,9 @@ def _insn_cache_key(code=None, expression=None, **kwargs): raise ValueError("Please specify either code or expression for instruction!") -@generator_factory(item_tags=("insn_id",), cache_key_generator=_insn_cache_key) +@generator_factory(item_tags=("insn_id",), + context_tags="kernel", + cache_key_generator=_insn_cache_key) def instruction(code=None, expression=None, **kwargs): assert (code is not None) or (expression is not None) assert not ((code is not None) and (expression is not None)) @@ -127,21 +137,26 @@ def instruction(code=None, expression=None, **kwargs): return id -@generator_factory(item_tags=("instruction",), cache_key_generator=lambda **kw: kw['id']) +@generator_factory(item_tags=("instruction",), + context_tags="kernel", + cache_key_generator=lambda **kw: kw['id']) def noop_instruction(**kwargs): - return loopy.NoOpInstruction(**kwargs) + return lp.NoOpInstruction(**kwargs) @generator_factory(item_tags=("transformation",), + context_tags="kernel", cache_key_generator=no_caching, ) def transform(trafo, *args): return (trafo, args) -@generator_factory(item_tags=("instruction", "barrier"), cache_key_generator=lambda **kw: kw['id']) +@generator_factory(item_tags=("instruction", "barrier"), + context_tags="kernel", + cache_key_generator=lambda **kw: kw['id']) def _barrier(**kwargs): - return loopy.BarrierInstruction(**kwargs) + return lp.BarrierInstruction(**kwargs) def barrier(**kwargs): diff --git a/python/dune/perftool/interactive.py b/python/dune/perftool/interactive.py index 8b7fb0cd3be913568b063b6683f8403e91ae515b..f64f99806bbda1db952ef1406891ace86d210d0c 100644 --- a/python/dune/perftool/interactive.py +++ b/python/dune/perftool/interactive.py @@ -3,7 +3,7 @@ from functools import partial from dune.perftool.generation import global_context from dune.perftool.loopy.transformations import get_loopy_transformations -from dune.perftool.pdelab.localoperator import assembly_routine_signature, AssemblyMethod +from dune.perftool.pdelab.localoperator import assembly_routine_signature, LoopyKernelMethod import os @@ -80,7 +80,7 @@ def show_code(which, kernel): with global_context(integral_type=which[0], form_type=which[1]): signature = assembly_routine_signature() - print("".join(AssemblyMethod(signature, kernel).generate())) + print("".join(LoopyKernelMethod(signature, kernel).generate())) print("Press Return to return to the previous menu") input() diff --git a/python/dune/perftool/loopy/buffer.py b/python/dune/perftool/loopy/buffer.py index bd94d995a847e3f4a5eaf2648fc41124fe4d4cea..e7207c2b41aa14e8b6c2c29f552d1aed0dbd7da3 100644 --- a/python/dune/perftool/loopy/buffer.py +++ b/python/dune/perftool/loopy/buffer.py @@ -44,7 +44,7 @@ class FlipFlopBuffer(object): return name -@generator_factory(item_tags=("kernel", "buffer"), cache_key_generator=lambda i, **kw: i) +@generator_factory(item_tags=("buffer"), cache_key_generator=lambda i, **kw: i, context_tags=("kernel",)) def initialize_buffer(identifier, base_storage_size=None, num=2): if base_storage_size is None: raise PerftoolLoopyError("The buffer for identifier {} has not been initialized.".format(identifier)) diff --git a/python/dune/perftool/pdelab/argument.py b/python/dune/perftool/pdelab/argument.py index fcea7ca6612b01c0f6f236a96c9b0a9b6aaf8562..f948414ab1728b8266a35f0d2ea48853f5e487c0 100644 --- a/python/dune/perftool/pdelab/argument.py +++ b/python/dune/perftool/pdelab/argument.py @@ -6,13 +6,13 @@ Namely: """ from dune.perftool.options import get_option -from dune.perftool.generation import (cached, - domain, +from dune.perftool.generation import (domain, function_mangler, iname, globalarg, valuearg, - get_global_context_value + get_global_context_value, + kernel_cached, ) from dune.perftool.pdelab.index import name_index from dune.perftool.pdelab.basis import (evaluate_coefficient, @@ -131,7 +131,7 @@ def name_applycontainer(restriction): return name -@cached +@kernel_cached def pymbolic_coefficient(container, lfs, index): # TODO introduce a proper type for local function spaces! if isinstance(lfs, str): diff --git a/python/dune/perftool/pdelab/basis.py b/python/dune/perftool/pdelab/basis.py index cdbdbd3557b32949d61045383e6c5820d9d60af8..2364b2bcfb1290c0a842342550b5c4402b9987a2 100644 --- a/python/dune/perftool/pdelab/basis.py +++ b/python/dune/perftool/pdelab/basis.py @@ -1,12 +1,12 @@ """ Generators for basis evaluations """ from dune.perftool.generation import (backend, - cached, class_member, generator_factory, get_backend, include_file, instruction, + kernel_cached, preamble, temporary_variable, ) @@ -39,7 +39,7 @@ def type_localbasis_cache(element): return "LocalBasisCacheWithoutReferences<{}>".format(type_gfs(element)) -@class_member("operator") +@class_member(classtag="operator") def define_localbasis_cache(element, name): include_file("dune/perftool/localbasiscache.hh", filetag="operatorfile") t = type_localbasis_cache(element) @@ -65,7 +65,7 @@ def declare_cache_temporary(element, restriction, which): @backend(interface="evaluate_basis") -@cached +@kernel_cached def evaluate_basis(leaf_element, name, restriction): lfs = name_leaf_lfs(leaf_element, restriction) temporary_variable(name, shape=(name_lfs_bound(lfs),), decl_method=declare_cache_temporary(leaf_element, restriction, 'Function')) @@ -93,7 +93,7 @@ def pymbolic_basis(leaf_element, restriction, number, context=''): @backend(interface="evaluate_grad") -@cached +@kernel_cached def evaluate_reference_gradient(leaf_element, name, restriction): lfs = name_leaf_lfs(leaf_element, restriction) temporary_variable(name, shape=(name_lfs_bound(lfs), 1, name_dimension()), decl_method=declare_cache_temporary(leaf_element, restriction, 'Jacobian')) @@ -129,7 +129,7 @@ def shape_as_pymbolic(shape): return tuple(_shape_as_pymbolic(s) for s in shape) -@cached +@kernel_cached def evaluate_coefficient(element, name, container, restriction, component): from ufl.functionview import select_subelement sub_element = select_subelement(element, component) @@ -165,7 +165,7 @@ def evaluate_coefficient(element, name, container, restriction, component): ) -@cached +@kernel_cached def evaluate_coefficient_gradient(element, name, container, restriction, component): # First we determine the rank of the tensor we are talking about from ufl.functionview import select_subelement diff --git a/python/dune/perftool/pdelab/geometry.py b/python/dune/perftool/pdelab/geometry.py index 9d8a1c28d1bfbd50af29e9c3f6c850b60a949de1..7afe2854f597d1a4207c1b8f20088bd11c0b9be0 100644 --- a/python/dune/perftool/pdelab/geometry.py +++ b/python/dune/perftool/pdelab/geometry.py @@ -1,13 +1,13 @@ from dune.perftool.ufl.modified_terminals import Restriction from dune.perftool.pdelab.restriction import restricted_name from dune.perftool.generation import (backend, - cached, domain, get_backend, get_global_context_value, globalarg, iname, include_file, + kernel_cached, preamble, temporary_variable, valuearg, @@ -276,7 +276,7 @@ def type_jacobian_inverse_transposed(restriction): return "typename {}::JacobianInverseTransposed".format(geo) -@cached +@kernel_cached def define_jacobian_inverse_transposed_temporary(restriction): @preamble def _define_jacobian_inverse_transposed_temporary(name, shape, shape_impl): diff --git a/python/dune/perftool/pdelab/index.py b/python/dune/perftool/pdelab/index.py index cc87edfdcb3ff2be39ec3fe4fa136c49e389b5cf..73bf17eb655f95dbe59db620cca367b736009ea8 100644 --- a/python/dune/perftool/pdelab/index.py +++ b/python/dune/perftool/pdelab/index.py @@ -1,10 +1,10 @@ -from dune.perftool.generation import cached +from dune.perftool.generation import kernel_cached from ufl.classes import MultiIndex, Index # Now define some commonly used generators that do not fall into a specific category -@cached +@kernel_cached def name_index(index): if isinstance(index, Index): # This failed for index > 9 because ufl placed curly brackets around diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py index e084853782f0d8c44ab85d984977c18f35a0e753..e13ab5288c1391e25c42cfd8d7185bc9f02e0106 100644 --- a/python/dune/perftool/pdelab/localoperator.py +++ b/python/dune/perftool/pdelab/localoperator.py @@ -10,6 +10,7 @@ from dune.perftool.generation import (backend, domain, dump_accumulate_timer, get_backend, + get_global_context_value, global_context, iname, include_file, @@ -54,22 +55,22 @@ def name_localoperator_file(formdata, data): return filename -@template_parameter("operator") +@template_parameter(classtag="operator") def lop_template_ansatz_gfs(): return "GFSU" -@template_parameter("operator") +@template_parameter(classtag="operator") def lop_template_test_gfs(): return "GFSV" -@template_parameter("operator") +@template_parameter(classtag="operator") def lop_template_range_field(): return "RF" -@class_member("operator") +@class_member(classtag="operator") def lop_domain_field(name): # TODO: Rethink for not Galerkin Method gfs = lop_template_ansatz_gfs() @@ -100,7 +101,7 @@ def name_initree_constructor_param(): return "iniParams" -@class_member("operator") +@class_member(classtag="operator") def define_initree(name): param_name = name_initree_constructor_param() include_file('dune/common/parametertree.hh', filetag="operatorfile") @@ -117,7 +118,7 @@ def ufl_measure_to_pdelab_measure(which): }.get(which) -@class_member(classtag="operator", access=AccessModifier.PUBLIC) +@class_member(classtag="operator") def _enum_pattern(which): return "enum {{ doPattern{} = true }};".format(which) @@ -138,7 +139,7 @@ def pattern_baseclass(): return _pattern_baseclass(ufl_measure_to_pdelab_measure(integral_type)) -@class_member(classtag="operator", access=AccessModifier.PUBLIC) +@class_member(classtag="operator") def _enum_alpha(which): return "enum {{ doAlpha{} = true }};".format(which) @@ -154,7 +155,7 @@ def name_initree_member(): return "_iniParams" -@class_basename("operator") +@class_basename(classtag="operator") def localoperator_basename(formdata, data): form_name = name_form(formdata, data) return "LocalOperator" + form_name.capitalize() @@ -474,16 +475,31 @@ def generate_kernel(integrals): visitor = UFL2LoopyVisitor(interface, measure, indexmap) get_backend(interface="accum_insn")(visitor, term, measure, subdomain_id) + tag = get_global_context_value("kernel") + knl = extract_kernel_from_cache(tag) + + # All items with the kernel tags can be destroyed once a kernel has been generated + from dune.perftool.generation import delete_cache_items + delete_cache_items(tag) + + return knl + + +def extract_kernel_from_cache(tag): # Extract the information, which is needed to create a loopy kernel. # First extracting it, might be useful to alter it before kernel generation. from dune.perftool.generation import retrieve_cache_functions, retrieve_cache_items from dune.perftool.loopy.target import DuneTarget - domains = [i for i in retrieve_cache_items("domain")] - instructions = [i for i in retrieve_cache_items("instruction")] - temporaries = {i.name: i for i in retrieve_cache_items("temporary")} - arguments = [i for i in retrieve_cache_items("argument")] - silenced = [l for l in retrieve_cache_items("silenced_warning")] - transformations = [t for t in retrieve_cache_items("transformation")] + domains = [i for i in retrieve_cache_items("{} and domain".format(tag))] + + if not domains: + domains = ["{[stupid] : 0<=stupid<1}"] + + instructions = [i for i in retrieve_cache_items("{} and instruction".format(tag))] + temporaries = {i.name: i for i in retrieve_cache_items("{} and temporary".format(tag))} + arguments = [i for i in retrieve_cache_items("{} and argument".format(tag))] + silenced = [l for l in retrieve_cache_items("{} and silenced_warning".format(tag))] + transformations = [t for t in retrieve_cache_items("{} and transformation".format(tag))] # Construct an options object from loopy import Options @@ -515,25 +531,22 @@ def generate_kernel(integrals): if get_option("sumfact"): # Vectorization of the quadrature loop insns = [i.id for i in lp.find_instructions(kernel, lp.match.Tagged("quadvec"))] - from dune.perftool.sumfact.quadrature import quadrature_inames - inames = quadrature_inames() + if insns: + from dune.perftool.sumfact.quadrature import quadrature_inames + inames = quadrature_inames() - from dune.perftool.loopy.transformations.collect_rotate import collect_vector_data_rotate - kernel = collect_vector_data_rotate(kernel, insns, inames) + from dune.perftool.loopy.transformations.collect_rotate import collect_vector_data_rotate + kernel = collect_vector_data_rotate(kernel, insns, inames) else: raise NotImplementedError("Only vectorizing sumfactoized code right now!") # Now add the preambles to the kernel - preambles = [(i, p) for i, p in enumerate(retrieve_cache_items("preamble"))] + preambles = [(i, p) for i, p in enumerate(retrieve_cache_items("{} and preamble".format(tag)))] kernel = kernel.copy(preambles=preambles) # Do the loopy preprocessing! kernel = preprocess_kernel(kernel) - # All items with the kernel tags can be destroyed once a kernel has been generated - from dune.perftool.generation import delete_cache_items - delete_cache_items("(not file) and (not clazz)") - return kernel @@ -581,11 +594,19 @@ class TimerMethod(ClassMember): ClassMember.__init__(self, content) -class AssemblyMethod(ClassMember): - def __init__(self, signature, kernel, filename): +class LoopyKernelMethod(ClassMember): + def __init__(self, signature, kernel, add_timings=True, initializer_list=[]): from loopy import generate_body from cgen import LiteralLines, Block content = signature + + # Add initializer list if this is a constructor + if initializer_list: + content[-1] = content[-1] + " :" + for init in initializer_list[:-1]: + content.append(" "*4 + init + ",") + content.append(" "*4 + initializer_list[-1]) + content.append('{') if kernel is not None: # Add kernel preamble @@ -593,7 +614,7 @@ class AssemblyMethod(ClassMember): content.append(' ' + p) # Start timer - if get_option('timer'): + if add_timings and get_option('timer'): timer_name = assembler_routine_name() + '_kernel' post_include('HP_DECLARE_TIMER({});'.format(timer_name), filetag='operatorfile') content.append(' ' + 'HP_TIMER_START({});'.format(timer_name)) @@ -603,7 +624,7 @@ class AssemblyMethod(ClassMember): content.extend(l for l in generate_body(kernel).split('\n')[1:-1]) # Stop timer - if get_option('timer'): + if add_timings and get_option('timer'): content.append(' ' + 'HP_TIMER_STOP({});'.format(timer_name)) content.append('}') @@ -618,17 +639,17 @@ def cgen_class_from_cache(tag, members=[]): base_classes = [bc for bc in retrieve_cache_items('{} and baseclass'.format(tag))] constructor_params = [bc for bc in retrieve_cache_items('{} and constructor_param'.format(tag))] - from cgen import Block - constructor_block = Block(contents=[i for i in retrieve_cache_items("{} and constructor_block".format(tag), make_generable=True)]) il = [i for i in retrieve_cache_items('{} and initializer'.format(tag))] pm = [m for m in retrieve_cache_items('{} and member'.format(tag))] tparams = [i for i in retrieve_cache_items('{} and template_param'.format(tag))] - from dune.perftool.cgen.clazz import Constructor - constructor = Constructor(block=constructor_block, arg_decls=constructor_params, clsname=basename, initializer_list=il) + # Construct the constructor + constructor_knl = extract_kernel_from_cache(tag) + signature = "{}({})".format(basename, ", ".join(next(iter(p.generate(with_semicolon=False))) for p in constructor_params)) + constructor = LoopyKernelMethod([signature], constructor_knl, add_timings=False, initializer_list=il) from dune.perftool.cgen import Class - return Class(basename, base_classes=base_classes, members=members + pm, constructors=[constructor], tparam_decls=tparams) + return Class(basename, base_classes=base_classes, members=[constructor] + members + pm, tparam_decls=tparams) def generate_localoperator_kernels(formdata, data): @@ -682,7 +703,8 @@ def generate_localoperator_kernels(formdata, data): enum_pattern() pattern_baseclass() enum_alpha() - kernel = generate_kernel(form.integrals_by_type(measure)) + with global_context(kernel=assembler_routine_name()): + kernel = generate_kernel(form.integrals_by_type(measure)) # Maybe add numerical differentiation if get_option("numerical_jacobian"): @@ -737,7 +759,8 @@ def generate_localoperator_kernels(formdata, data): with global_context(form_type="jacobian"): for measure in set(i.integral_type() for i in jacform.integrals()): with global_context(integral_type=measure): - kernel = generate_kernel(jacform.integrals_by_type(measure)) + with global_context(kernel=assembler_routine_name()): + kernel = generate_kernel(jacform.integrals_by_type(measure)) operator_kernels[(measure, 'jacobian')] = kernel # Generate dummy functions for those kernels, that vanished in the differentiation process @@ -762,7 +785,8 @@ def generate_localoperator_kernels(formdata, data): with global_context(form_type="jacobian_apply"): for measure in set(i.integral_type() for i in jac_apply_form.integrals()): with global_context(integral_type=measure): - kernel = generate_kernel(jac_apply_form.integrals_by_type(measure)) + with global_context(kernel=assembler_routine_name()): + kernel = generate_kernel(jac_apply_form.integrals_by_type(measure)) operator_kernels[(measure, 'jacobian_apply')] = kernel # Generate dummy functions for those kernels, that vanished in the differentiation process @@ -785,7 +809,7 @@ def generate_localoperator_file(formdata, kernels, filename): it, ft = method with global_context(integral_type=it, form_type=ft): signature = assembly_routine_signature(formdata) - operator_methods.append(AssemblyMethod(signature, kernel, filename)) + operator_methods.append(LoopyKernelMethod(signature, kernel)) if get_option('timer'): include_file('dune/perftool/common/timer.hh', filetag='operatorfile') diff --git a/python/dune/perftool/pdelab/parameter.py b/python/dune/perftool/pdelab/parameter.py index 602182163e0ad531e7992a04525a9cc36acd4802..0d3aa6a29d11f8c520687ded1fb839c95797dcd7 100644 --- a/python/dune/perftool/pdelab/parameter.py +++ b/python/dune/perftool/pdelab/parameter.py @@ -1,12 +1,12 @@ """ Generators for parameter functions """ -from dune.perftool.generation import (cached, - class_basename, +from dune.perftool.generation import (class_basename, class_member, constructor_parameter, generator_factory, get_backend, initializer_list, + kernel_cached, preamble, temporary_variable ) @@ -25,13 +25,13 @@ from dune.perftool.pdelab.localoperator import (class_type_from_cache, ) -@class_basename("parameterclass") +@class_basename(classtag="parameterclass") def parameterclass_basename(formdata, data): lopbase = localoperator_basename(formdata, data) return "{}Params".format(lopbase) -@class_member("operator") +@class_member(classtag="operator") def define_parameterclass(name): _, t = class_type_from_cache("parameterclass") constructor_parameter("const {}&".format(t), name + "_", classtag="operator") @@ -44,7 +44,7 @@ def name_paramclass(): return "param" -@class_member(classtag="parameterclass", access=AccessModifier.PRIVATE) +@class_member(classtag="parameterclass") def define_time(name): initializer_list(name, ["0.0"], classtag="parameterclass") return "double {};".format(name) @@ -55,7 +55,7 @@ def name_time(): return "t" -@class_member("parameterclass", access=AccessModifier.PUBLIC) +@class_member(classtag="parameterclass") def define_set_time_method(): time_name = name_time() # TODO double? @@ -81,7 +81,7 @@ def component_to_tree_path(element, component): return _flatten(subel) -@class_member("parameterclass", access=AccessModifier.PUBLIC) +@class_member(classtag="parameterclass") def define_parameter_function_class_member(name, expr, baset, shape, cell): t = construct_nested_fieldvector(baset, shape) @@ -206,7 +206,7 @@ def construct_nested_fieldvector(t, shape): return 'Dune::FieldVector<{}, {}>'.format(construct_nested_fieldvector(t, shape[1:]), shape[0]) -@cached +@kernel_cached def cell_parameter_function(name, expr, restriction, cellwise_constant, t='double'): shape = expr.ufl_element().value_shape() shape_impl = ('fv',) * len(shape) @@ -218,7 +218,7 @@ def cell_parameter_function(name, expr, restriction, cellwise_constant, t='doubl evaluate_cell_parameter_function(name, restriction) -@cached +@kernel_cached def intersection_parameter_function(name, expr, cellwise_constant, t='double'): shape = expr.ufl_element().value_shape() shape_impl = ('fv',) * len(shape) diff --git a/python/dune/perftool/pdelab/quadrature.py b/python/dune/perftool/pdelab/quadrature.py index 72a42aa73ca20e0543f390ccbd193b5fc2895943..f6be15fd3e5f3e0a64f79e2573d28f2943459a40 100644 --- a/python/dune/perftool/pdelab/quadrature.py +++ b/python/dune/perftool/pdelab/quadrature.py @@ -1,7 +1,6 @@ import numpy from dune.perftool.generation import (backend, - cached, class_member, domain, get_backend, @@ -95,7 +94,7 @@ def fill_quadrature_points_cache(name): return "fillQuadraturePointsCache({}, {}, {});".format(geo, quad_order, name) -@class_member("operator") +@class_member(classtag="operator") def typedef_quadrature_points(name): range_field = lop_template_range_field() dim = _local_dim() @@ -108,7 +107,7 @@ def type_quadrature_points(name): return name -@class_member("operator") +@class_member(classtag="operator") def define_quadrature_points(name): qp_type = type_quadrature_points(name) return "mutable std::vector<{}> {};".format(qp_type, name) @@ -148,7 +147,7 @@ def fill_quadrature_weights_cache(name): return "fillQuadratureWeightsCache({}, {}, {});".format(geo, quad_order, name) -@class_member("operator") +@class_member(classtag="operator") def typedef_quadrature_weights(name): range_field = lop_template_range_field() dim = _local_dim() @@ -167,7 +166,7 @@ def type_quadrature_weights(name): return name -@class_member("operator") +@class_member(classtag="operator") def define_quadrature_weights(name): qw_type = type_quadrature_weights(name) return "mutable std::vector<{}> {};".format(qw_type, name) diff --git a/python/dune/perftool/pdelab/spaces.py b/python/dune/perftool/pdelab/spaces.py index 18fe134f453119a9345d42725cb53bbb3f7d88a8..710f86b0dbfd5093ab21c05459992848c42c90ea 100644 --- a/python/dune/perftool/pdelab/spaces.py +++ b/python/dune/perftool/pdelab/spaces.py @@ -99,7 +99,7 @@ def name_leaf_lfs(leaf_element, restriction, val=None): return val -@generator_factory(cache_key_generator=lambda e, r, c, **kw: (e, r, c)) +@generator_factory(cache_key_generator=lambda e, r, c, **kw: (e, r, c), context_tags=("kernel",)) def name_lfs(element, restriction, component, prefix=None): # Omitting the prefix is only valid upon a second call, which will # result in a cache hit. @@ -178,7 +178,7 @@ def traverse_lfs_tree(arg): type_gfs(arg.argexpr.ufl_element(), basetype=gfs_basename, index_stack=()) -@generator_factory(item_tags=("iname",), cache_key_generator=lambda e, r, c: (e, c)) +@generator_factory(item_tags=("iname",), cache_key_generator=lambda e, r, c: (e, c), context_tags=("kernel",)) def _lfs_iname(element, restriction, context): lfs = name_leaf_lfs(element, restriction) bound = name_lfs_bound(lfs) diff --git a/python/dune/perftool/sumfact/amatrix.py b/python/dune/perftool/sumfact/amatrix.py index 1d3fa814f18dba7b819783cdfdfd27bd5e898433..5e717f164f6df96e869a8249c3a482f5575feb18 100644 --- a/python/dune/perftool/sumfact/amatrix.py +++ b/python/dune/perftool/sumfact/amatrix.py @@ -5,7 +5,6 @@ from dune.perftool.options import get_option from dune.perftool.pdelab.argument import name_coefficientcontainer from dune.perftool.generation import (class_member, - constructor_block, domain, function_mangler, get_global_context_value, @@ -13,6 +12,8 @@ from dune.perftool.generation import (class_member, iname, include_file, initializer_list, + instruction, + preamble, silenced_warning, temporary_variable, valuearg @@ -58,7 +59,7 @@ def colmajoraccess_mangler(target, func, dtypes): return CallMangleInfo(func.name, (NumpyType(numpy.float64),), (NumpyType(numpy.int32), NumpyType(numpy.int32))) -@class_member("operator") +@class_member(classtag="operator") def define_alignment(name): alignment = get_option("sumfact_alignment") return "enum {{ {} = {} }};".format(name, str(alignment)) @@ -83,7 +84,7 @@ def quadrature_points_per_direction(): return nb_qp -@class_member("operator") +@class_member(classtag="operator") def define_number_of_quadrature_points_per_direction(name): number_qp = quadrature_points_per_direction() return "enum {{ {} = {} }};".format(name, str(number_qp)) @@ -104,7 +105,7 @@ def basis_functions_per_direction(): return polynomial_degree() + 1 -@class_member("operator") +@class_member(classtag="operator") def define_number_of_basis_functions_per_direction(name): number_basis = basis_functions_per_direction() return "enum {{ {} = {} }};".format(name, str(number_basis)) @@ -116,7 +117,7 @@ def name_number_of_basis_functions_per_direction(): return name -@class_member("operator") +@class_member(classtag="operator") def define_oned_quadrature_weights(name): range_field = lop_template_range_field() number_qp = name_number_of_quadrature_points_per_direction() @@ -130,7 +131,7 @@ def name_oned_quadrature_weights(): return name -@class_member("operator") +@class_member(classtag="operator") def define_oned_quadrature_points(name): range_field = lop_template_range_field() number_qp = name_number_of_quadrature_points_per_direction() @@ -144,7 +145,7 @@ def name_oned_quadrature_points(): return name -@class_member("operator") +@class_member(classtag="operator") def typedef_polynomials(name): range_field = lop_template_range_field() domain_field = name_domain_field() @@ -169,7 +170,7 @@ def type_polynomials(): return name -@class_member("operator") +@class_member(classtag="operator") def define_polynomials(name): polynomials_type = type_polynomials() return "{} {};".format(polynomials_type, name) @@ -181,7 +182,7 @@ def name_polynomials(): return name -@constructor_block("operator") +@preamble(kernel="operator") def sort_quadrature_points_weights(): range_field = lop_template_range_field() domain_field = name_domain_field() @@ -192,7 +193,13 @@ def sort_quadrature_points_weights(): return "onedQuadraturePointsWeights<{}, {}, {}>({}, {});".format(range_field, domain_field, number_qp, qp, qw) -@constructor_block("operator") +@iname(kernel="operator") +def theta_iname(name, bound): + name = "{}_{}".format(name, bound) + domain(name, bound) + return name + + def construct_theta(name, transpose, derivative): # Make sure that the quadrature points are sorted sort_quadrature_points_weights() @@ -204,18 +211,21 @@ def construct_theta(name, transpose, derivative): polynomials = name_polynomials() qp = name_oned_quadrature_points() + i = theta_iname("i", shape[0]) + j = theta_iname("j", shape[1]) + # access = "j,i" if transpose else "i,j" basispol = "dp" if derivative else "p" - polynomial_access = "i,{}[j]".format(qp) if transpose else "j,{}[i]".format(qp) + polynomial_access = "{},{}[{}]".format(i, qp, j) if transpose else "{},{}[{}]".format(j, qp, i) - return ["for (std::size_t i=0; i<{}; i++){{".format(shape[0]), - " for (std::size_t j=0; j<{}; j++){{".format(shape[1]), - " {}.colmajoraccess(i,j) = {}.{}({});".format(name, polynomials, basispol, polynomial_access), - " }", - "}"] + return instruction(code="{}.colmajoraccess({},{}) = {}.{}({});".format(name, i, j, polynomials, basispol, polynomial_access), + kernel="operator", + within_inames=frozenset({i, j}), + within_inames_is_final=True, + ) -@class_member("operator") +@class_member(classtag="operator") def typedef_theta(name): include_file("dune/perftool/sumfact/alignedmatvec.hh", filetag="operatorfile") alignment = name_alignment() @@ -229,7 +239,7 @@ def type_theta(): return name -@class_member("operator") +@class_member(classtag="operator") def define_theta(name, shape, transpose, derivative): theta_type = type_theta() initializer_list(name, [str(axis) for axis in shape], classtag="operator") diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py index 77edc9e0151799baa1f9c60c6f7355d1dbcc2006..0af2d38c0c20ddaaf11cb5569a3a24355df1e8d3 100644 --- a/python/dune/perftool/sumfact/basis.py +++ b/python/dune/perftool/sumfact/basis.py @@ -4,12 +4,12 @@ NB: Basis evaluation is only needed for the trial function argument in jacobians multiplication with the test function is part of the sum factorization kernel. """ from dune.perftool.generation import (backend, - cached, domain, get_counter, get_global_context_value, iname, instruction, + kernel_cached, temporary_variable, ) from dune.perftool.sumfact.amatrix import (AMatrix, @@ -39,7 +39,7 @@ def name_sumfact_base_buffer(): return name -@cached +@kernel_cached def sumfact_evaluate_coefficient_gradient(element, name, restriction, component): # Get a temporary for the gradient from ufl.functionview import select_subelement @@ -97,7 +97,7 @@ def sumfact_evaluate_coefficient_gradient(element, name, restriction, component) ) -@cached +@kernel_cached def pymbolic_trialfunction_gradient(element, restriction, component): rawname = "gradu" + "_".join(str(c) for c in component) name = restricted_name(rawname, restriction) @@ -108,7 +108,7 @@ def pymbolic_trialfunction_gradient(element, restriction, component): return Variable(name) -@cached +@kernel_cached def pymbolic_trialfunction(element, restriction, component): theta = name_theta() rows = quadrature_points_per_direction() @@ -151,7 +151,7 @@ def lfs_inames(element, restriction, number=1, context=''): @backend(interface="evaluate_basis") -@cached +@kernel_cached def evaluate_basis(element, name, restriction): temporary_variable(name, shape=()) theta = name_theta() @@ -183,7 +183,7 @@ def pymbolic_basis(element, restriction, number): @backend(interface="evaluate_grad") -@cached +@kernel_cached def evaluate_reference_gradient(element, name, restriction): from dune.perftool.pdelab.geometry import name_dimension temporary_variable( diff --git a/python/test/dune/perftool/generation/test_cache.py b/python/test/dune/perftool/generation/test_cache.py index f974a518f5e4f177e123f3b3222986b31c1f5158..47b9bc7f5e261718aaaa25a6b6a4b3c904d35f83 100644 --- a/python/test/dune/perftool/generation/test_cache.py +++ b/python/test/dune/perftool/generation/test_cache.py @@ -1,11 +1,12 @@ from collections import Counter -from dune.perftool.generation.cache import(delete_cache_items, - generator_factory, - no_caching, - retrieve_cache_functions, - retrieve_cache_items, - ) +from dune.perftool.generation import(delete_cache_items, + generator_factory, + global_context, + no_caching, + retrieve_cache_functions, + retrieve_cache_items, + ) def print_cache(): @@ -202,3 +203,127 @@ def test_no_caching_function(): assert compare(["one", "one"], list(retrieve_cache_items("no_caching"))) no_caching_function("two") assert compare(["one", "one", "two"], list(retrieve_cache_items("no_caching"))) + + +# ===================== +# Test multiple kernels +# ===================== + + +def test_multiple_kernels_1(): + preamble = generator_factory(item_tags=("preamble",), context_tags=("kernel",)) + + @preamble + def pre1(): + return "blubb" + + @preamble + def pre2(): + return "bla" + + with global_context(kernel="k1"): + pre1() + with global_context(kernel="k2"): + pre2() + + preambles = [p for p in retrieve_cache_items("preamble")] + assert len(preambles) == 2 + + k1, = retrieve_cache_items("k1") + assert k1 == "blubb" + + k2, = retrieve_cache_items("k2") + assert k2 == "bla" + + delete_cache_items() + + +def test_multiple_kernels_2(): + preamble = generator_factory(item_tags=("preamble",), context_tags=("kernel",)) + + @preamble + def pre1(): + return "blubb" + + @preamble + def pre2(): + return "bla" + + with global_context(kernel="k1"): + with global_context(kernel="k2"): + pre2() + pre1() + + preambles = [p for p in retrieve_cache_items("preamble")] + assert len(preambles) == 2 + + k1, = retrieve_cache_items("k1") + assert k1 == "blubb" + + k2, = retrieve_cache_items("k2") + assert k2 == "bla" + + delete_cache_items() + + +def test_multiple_kernels_3(): + preamble = generator_factory(item_tags=("preamble",), context_tags=("kernel",)) + + @preamble(kernel="k3") + def pre3(): + return "foo" + + @preamble(kernel="k4") + def pre4(): + return "bar" + + pre3() + pre4() + + preambles = [p for p in retrieve_cache_items("preamble")] + assert len(preambles) == 2 + + k3, = retrieve_cache_items("k3") + assert k3 == "foo" + + k4, = retrieve_cache_items("k4") + assert k4 == "bar" + + delete_cache_items() + + +def test_multiple_kernels_4(): + gen = generator_factory(item_tags=("tag",), context_tags=("kernel",), no_deco=True) + + with global_context(kernel="k1"): + gen("foo") + + with global_context(kernel="k2"): + gen("bar") + + assert len([i for i in retrieve_cache_items("tag")]) == 2 + + k1, = retrieve_cache_items("k1") + assert k1 == "foo" + + k2, = retrieve_cache_items("k2") + assert k2 == "bar" + + delete_cache_items() + + +def test_multiple_kernels_5(): + gen = generator_factory(item_tags=("tag",), context_tags=("kernel",), no_deco=True) + + gen("foo", kernel="k1") + gen("bar", kernel="k2") + + assert len([i for i in retrieve_cache_items("tag")]) == 2 + + k1, = retrieve_cache_items("k1") + assert k1 == "foo" + + k2, = retrieve_cache_items("k2") + assert k2 == "bar" + + delete_cache_items()