diff --git a/python/dune/perftool/generation/cache.py b/python/dune/perftool/generation/cache.py index 9f7aba9be5c24e06a7f16e98624ba04b08eefa3e..58b67988800a8be0cd2200ce340862b81b150764 100644 --- a/python/dune/perftool/generation/cache.py +++ b/python/dune/perftool/generation/cache.py @@ -85,21 +85,31 @@ class _RegisteredFunction(object): else: return self._memoize_cache[key] - def call(self, *args, **kwargs): - # Get the cache key from the given arguments - cache_key = self.cache_key_generator(*args, **kwargs) + def __call__(self, *args, **kwargs): + # Modify the kwargs to include any context tags kept with the generator + for tag in self.context_tags: + if tag in self.kwargs and tag not in kwargs: + kwargs[tag] = self.kwargs[tag] - # Make sure that all keyword arguments have vanished from the cache_args - assert (lambda *a, **k: len(k) == 0)(cache_key) + # Keep an additional dictionary without context tags + without_context = {k: v for k, v in kwargs.items() if k not in self.context_tags} - # Add any context tags to the cache key - context_key = tuple(get_global_context_value(t, None) for t in self.context_tags) + # Get the cache key from the given arguments + context_key = tuple(kwargs.get(t, t + "_default") for t in self.context_tags) + cache_key = self.cache_key_generator(*args, **without_context) cache_key = (cache_key, context_key) # check whether we have a cache hit if cache_key not in self._memoize_cache: - # evaluate the original function - val = self.on_store(self.func(*args, **kwargs)) + # evaluate the original function: Once with context tags, once without. + # Reason: Some generators use their context tag to pass it on to other + # generators. That should be possible. However, those that do not do this + # get an unknown keyword... + try: + val = self.on_store(self.func(*args, **kwargs)) + except: + val = self.on_store(self.func(*args, **without_context)) + # Maybe wrap it with a counter! if self.counted: val = (get_counter('__cache_counted'), val) @@ -109,14 +119,6 @@ class _RegisteredFunction(object): # Return the result for immediate usage return self._get_content(cache_key) - def __call__(self, *args, **kwargs): - additional_kw = {k: kwargs[k] for k in kwargs if k in self.context_tags} - for k, v in self.kwargs.items(): - additional_kw[k] = v - kwargs = {k: kwargs[k] for k in kwargs if k not in self.context_tags} - with global_context(**additional_kw): - return self.call(*args, **kwargs) - def generator_factory(**factory_kwargs): """ A function decorator factory @@ -151,8 +153,7 @@ def generator_factory(**factory_kwargs): context_tags: tuple, str A single tag or tuple thereof, that will be added to the cache key. This feature can be used to maintain multiple sets of memoized function evaluations, - for example if you generate multiple loopy kernels at the same time. The - given strings are used to look up in the global context manager for a tag. + for example if you generate multiple loopy kernels at the same time. """ # Tuplize the item_tags parameter if "item_tags" in factory_kwargs and isinstance(factory_kwargs["item_tags"], str): diff --git a/python/dune/perftool/generation/cpp.py b/python/dune/perftool/generation/cpp.py index 507c8217d335936ec3c70ceb845984a734ee2738..5bb21bb092fb30ea8fc108cc00b4fb3e5f79e5ba 100644 --- a/python/dune/perftool/generation/cpp.py +++ b/python/dune/perftool/generation/cpp.py @@ -23,14 +23,14 @@ def include_file(include, system=False): @generator_factory(item_tags=("clazz", "initializer"), counted=True, context_tags=("classtag",), cache_key_generator=lambda o, p: o) -def initializer_list(obj, params, classtag=None): +def initializer_list(obj, params): return "{}({})".format(obj, ", ".join(params)) @generator_factory(item_tags=("clazz", "baseclass"), context_tags=("classtag",), counted=True) -def base_class(baseclass, access=AccessModifier.PUBLIC, construction=[], **kwargs): +def base_class(baseclass, access=AccessModifier.PUBLIC, construction=[], classtag=None): if construction: - initializer_list(baseclass, construction, **kwargs) + initializer_list(baseclass, construction, classtag=classtag) return BaseClass(baseclass, inheritance=access) diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py index e13ab5288c1391e25c42cfd8d7185bc9f02e0106..0234b4221c30190e983a2859ff857e4642273fdf 100644 --- a/python/dune/perftool/pdelab/localoperator.py +++ b/python/dune/perftool/pdelab/localoperator.py @@ -475,12 +475,11 @@ def generate_kernel(integrals): visitor = UFL2LoopyVisitor(interface, measure, indexmap) get_backend(interface="accum_insn")(visitor, term, measure, subdomain_id) - tag = get_global_context_value("kernel") - knl = extract_kernel_from_cache(tag) + knl = extract_kernel_from_cache("kernel_default") # All items with the kernel tags can be destroyed once a kernel has been generated from dune.perftool.generation import delete_cache_items - delete_cache_items(tag) + delete_cache_items("kernel_default") return knl @@ -604,8 +603,8 @@ class LoopyKernelMethod(ClassMember): if initializer_list: content[-1] = content[-1] + " :" for init in initializer_list[:-1]: - content.append(" "*4 + init + ",") - content.append(" "*4 + initializer_list[-1]) + content.append(" " * 4 + init + ",") + content.append(" " * 4 + initializer_list[-1]) content.append('{') if kernel is not None: diff --git a/python/dune/perftool/sumfact/amatrix.py b/python/dune/perftool/sumfact/amatrix.py index 5e717f164f6df96e869a8249c3a482f5575feb18..387fb400fd186e593fb36954c1582e0ae781d65d 100644 --- a/python/dune/perftool/sumfact/amatrix.py +++ b/python/dune/perftool/sumfact/amatrix.py @@ -84,18 +84,6 @@ def quadrature_points_per_direction(): return nb_qp -@class_member(classtag="operator") -def define_number_of_quadrature_points_per_direction(name): - number_qp = quadrature_points_per_direction() - return "enum {{ {} = {} }};".format(name, str(number_qp)) - - -def name_number_of_quadrature_points_per_direction(): - name = "m" - define_number_of_quadrature_points_per_direction(name) - return name - - def polynomial_degree(): form = get_global_context_value("formdata").preprocessed_form return form.coefficients()[0].ufl_element()._degree @@ -105,22 +93,10 @@ def basis_functions_per_direction(): return polynomial_degree() + 1 -@class_member(classtag="operator") -def define_number_of_basis_functions_per_direction(name): - number_basis = basis_functions_per_direction() - return "enum {{ {} = {} }};".format(name, str(number_basis)) - - -def name_number_of_basis_functions_per_direction(): - name = "n" - define_number_of_basis_functions_per_direction(name) - return name - - @class_member(classtag="operator") def define_oned_quadrature_weights(name): range_field = lop_template_range_field() - number_qp = name_number_of_quadrature_points_per_direction() + number_qp = quadrature_points_per_direction() return "{} {}[{}];".format(range_field, name, number_qp) @@ -134,7 +110,7 @@ def name_oned_quadrature_weights(): @class_member(classtag="operator") def define_oned_quadrature_points(name): range_field = lop_template_range_field() - number_qp = name_number_of_quadrature_points_per_direction() + number_qp = quadrature_points_per_direction() return "{} {}[{}];".format(range_field, name, number_qp) @@ -186,7 +162,7 @@ def name_polynomials(): def sort_quadrature_points_weights(): range_field = lop_template_range_field() domain_field = name_domain_field() - number_qp = name_number_of_quadrature_points_per_direction() + number_qp = quadrature_points_per_direction() qp = name_oned_quadrature_points() qw = name_oned_quadrature_weights() include_file("dune/perftool/sumfact/onedquadrature.hh", filetag="operatorfile") @@ -196,7 +172,7 @@ def sort_quadrature_points_weights(): @iname(kernel="operator") def theta_iname(name, bound): name = "{}_{}".format(name, bound) - domain(name, bound) + domain(name, bound, kernel="operator") return name @@ -205,9 +181,9 @@ def construct_theta(name, transpose, derivative): sort_quadrature_points_weights() if transpose: - shape = (name_number_of_basis_functions_per_direction(), name_number_of_quadrature_points_per_direction()) + shape = (basis_functions_per_direction(), quadrature_points_per_direction()) else: - shape = (name_number_of_quadrature_points_per_direction(), name_number_of_basis_functions_per_direction()) + shape = (quadrature_points_per_direction(), basis_functions_per_direction()) polynomials = name_polynomials() qp = name_oned_quadrature_points() @@ -217,7 +193,6 @@ def construct_theta(name, transpose, derivative): # access = "j,i" if transpose else "i,j" basispol = "dp" if derivative else "p" polynomial_access = "{},{}[{}]".format(i, qp, j) if transpose else "{},{}[{}]".format(j, qp, i) - return instruction(code="{}.colmajoraccess({},{}) = {}.{}({});".format(name, i, j, polynomials, basispol, polynomial_access), kernel="operator", within_inames=frozenset({i, j}), diff --git a/python/dune/perftool/sumfact/quadrature.py b/python/dune/perftool/sumfact/quadrature.py index dfc6a81db86734c327379d08cd267b984c0dff36..83a5ebb58b1806f6602ca49039adcefb18214a9c 100644 --- a/python/dune/perftool/sumfact/quadrature.py +++ b/python/dune/perftool/sumfact/quadrature.py @@ -7,7 +7,7 @@ from dune.perftool.generation import (backend, temporary_variable, ) -from dune.perftool.sumfact.amatrix import (name_number_of_quadrature_points_per_direction, +from dune.perftool.sumfact.amatrix import (quadrature_points_per_direction, name_oned_quadrature_points, name_oned_quadrature_weights, ) @@ -71,7 +71,7 @@ def pymbolic_base_weight(): @iname def sumfact_quad_iname(d, context): name = "quad_{}_{}".format(context, d) - domain(name, name_number_of_quadrature_points_per_direction()) + domain(name, quadrature_points_per_direction()) return name diff --git a/python/test/dune/perftool/generation/test_cache.py b/python/test/dune/perftool/generation/test_cache.py index 47b9bc7f5e261718aaaa25a6b6a4b3c904d35f83..f28c6e06227631ad2161527367e51c81920119f5 100644 --- a/python/test/dune/perftool/generation/test_cache.py +++ b/python/test/dune/perftool/generation/test_cache.py @@ -213,18 +213,16 @@ def test_no_caching_function(): def test_multiple_kernels_1(): preamble = generator_factory(item_tags=("preamble",), context_tags=("kernel",)) - @preamble + @preamble(kernel="k1") def pre1(): return "blubb" - @preamble + @preamble(kernel="k2") def pre2(): return "bla" - with global_context(kernel="k1"): - pre1() - with global_context(kernel="k2"): - pre2() + pre1() + pre2() preambles = [p for p in retrieve_cache_items("preamble")] assert len(preambles) == 2 @@ -241,18 +239,16 @@ def test_multiple_kernels_1(): def test_multiple_kernels_2(): preamble = generator_factory(item_tags=("preamble",), context_tags=("kernel",)) - @preamble + @preamble(kernel="k1") def pre1(): return "blubb" - @preamble + @preamble(kernel="k2") def pre2(): + pre1() return "bla" - with global_context(kernel="k1"): - with global_context(kernel="k2"): - pre2() - pre1() + pre2() preambles = [p for p in retrieve_cache_items("preamble")] assert len(preambles) == 2 @@ -267,52 +263,6 @@ def test_multiple_kernels_2(): def test_multiple_kernels_3(): - preamble = generator_factory(item_tags=("preamble",), context_tags=("kernel",)) - - @preamble(kernel="k3") - def pre3(): - return "foo" - - @preamble(kernel="k4") - def pre4(): - return "bar" - - pre3() - pre4() - - preambles = [p for p in retrieve_cache_items("preamble")] - assert len(preambles) == 2 - - k3, = retrieve_cache_items("k3") - assert k3 == "foo" - - k4, = retrieve_cache_items("k4") - assert k4 == "bar" - - delete_cache_items() - - -def test_multiple_kernels_4(): - gen = generator_factory(item_tags=("tag",), context_tags=("kernel",), no_deco=True) - - with global_context(kernel="k1"): - gen("foo") - - with global_context(kernel="k2"): - gen("bar") - - assert len([i for i in retrieve_cache_items("tag")]) == 2 - - k1, = retrieve_cache_items("k1") - assert k1 == "foo" - - k2, = retrieve_cache_items("k2") - assert k2 == "bar" - - delete_cache_items() - - -def test_multiple_kernels_5(): gen = generator_factory(item_tags=("tag",), context_tags=("kernel",), no_deco=True) gen("foo", kernel="k1")