Skip to content
Snippets Groups Projects
Commit 4b6609e6 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Avoid having generation magic use global context

That was a bad idea from the beginning! It lead to a situation,
where calls from within a generator inherited their context tags.
Should be fixed now.
parent 9cefb0c1
No related branches found
No related tags found
No related merge requests found
...@@ -85,21 +85,31 @@ class _RegisteredFunction(object): ...@@ -85,21 +85,31 @@ class _RegisteredFunction(object):
else: else:
return self._memoize_cache[key] return self._memoize_cache[key]
def call(self, *args, **kwargs): def __call__(self, *args, **kwargs):
# Get the cache key from the given arguments # Modify the kwargs to include any context tags kept with the generator
cache_key = self.cache_key_generator(*args, **kwargs) for tag in self.context_tags:
if tag in self.kwargs and tag not in kwargs:
kwargs[tag] = self.kwargs[tag]
# Make sure that all keyword arguments have vanished from the cache_args # Keep an additional dictionary without context tags
assert (lambda *a, **k: len(k) == 0)(cache_key) without_context = {k: v for k, v in kwargs.items() if k not in self.context_tags}
# Add any context tags to the cache key # Get the cache key from the given arguments
context_key = tuple(get_global_context_value(t, None) for t in self.context_tags) context_key = tuple(kwargs.get(t, t + "_default") for t in self.context_tags)
cache_key = self.cache_key_generator(*args, **without_context)
cache_key = (cache_key, context_key) cache_key = (cache_key, context_key)
# check whether we have a cache hit # check whether we have a cache hit
if cache_key not in self._memoize_cache: if cache_key not in self._memoize_cache:
# evaluate the original function # evaluate the original function: Once with context tags, once without.
val = self.on_store(self.func(*args, **kwargs)) # Reason: Some generators use their context tag to pass it on to other
# generators. That should be possible. However, those that do not do this
# get an unknown keyword...
try:
val = self.on_store(self.func(*args, **kwargs))
except:
val = self.on_store(self.func(*args, **without_context))
# Maybe wrap it with a counter! # Maybe wrap it with a counter!
if self.counted: if self.counted:
val = (get_counter('__cache_counted'), val) val = (get_counter('__cache_counted'), val)
...@@ -109,14 +119,6 @@ class _RegisteredFunction(object): ...@@ -109,14 +119,6 @@ class _RegisteredFunction(object):
# Return the result for immediate usage # Return the result for immediate usage
return self._get_content(cache_key) return self._get_content(cache_key)
def __call__(self, *args, **kwargs):
additional_kw = {k: kwargs[k] for k in kwargs if k in self.context_tags}
for k, v in self.kwargs.items():
additional_kw[k] = v
kwargs = {k: kwargs[k] for k in kwargs if k not in self.context_tags}
with global_context(**additional_kw):
return self.call(*args, **kwargs)
def generator_factory(**factory_kwargs): def generator_factory(**factory_kwargs):
""" A function decorator factory """ A function decorator factory
...@@ -151,8 +153,7 @@ def generator_factory(**factory_kwargs): ...@@ -151,8 +153,7 @@ def generator_factory(**factory_kwargs):
context_tags: tuple, str context_tags: tuple, str
A single tag or tuple thereof, that will be added to the cache key. This A single tag or tuple thereof, that will be added to the cache key. This
feature can be used to maintain multiple sets of memoized function evaluations, feature can be used to maintain multiple sets of memoized function evaluations,
for example if you generate multiple loopy kernels at the same time. The for example if you generate multiple loopy kernels at the same time.
given strings are used to look up in the global context manager for a tag.
""" """
# Tuplize the item_tags parameter # Tuplize the item_tags parameter
if "item_tags" in factory_kwargs and isinstance(factory_kwargs["item_tags"], str): if "item_tags" in factory_kwargs and isinstance(factory_kwargs["item_tags"], str):
......
...@@ -23,14 +23,14 @@ def include_file(include, system=False): ...@@ -23,14 +23,14 @@ def include_file(include, system=False):
@generator_factory(item_tags=("clazz", "initializer"), counted=True, context_tags=("classtag",), cache_key_generator=lambda o, p: o) @generator_factory(item_tags=("clazz", "initializer"), counted=True, context_tags=("classtag",), cache_key_generator=lambda o, p: o)
def initializer_list(obj, params, classtag=None): def initializer_list(obj, params):
return "{}({})".format(obj, ", ".join(params)) return "{}({})".format(obj, ", ".join(params))
@generator_factory(item_tags=("clazz", "baseclass"), context_tags=("classtag",), counted=True) @generator_factory(item_tags=("clazz", "baseclass"), context_tags=("classtag",), counted=True)
def base_class(baseclass, access=AccessModifier.PUBLIC, construction=[], **kwargs): def base_class(baseclass, access=AccessModifier.PUBLIC, construction=[], classtag=None):
if construction: if construction:
initializer_list(baseclass, construction, **kwargs) initializer_list(baseclass, construction, classtag=classtag)
return BaseClass(baseclass, inheritance=access) return BaseClass(baseclass, inheritance=access)
......
...@@ -475,12 +475,11 @@ def generate_kernel(integrals): ...@@ -475,12 +475,11 @@ def generate_kernel(integrals):
visitor = UFL2LoopyVisitor(interface, measure, indexmap) visitor = UFL2LoopyVisitor(interface, measure, indexmap)
get_backend(interface="accum_insn")(visitor, term, measure, subdomain_id) get_backend(interface="accum_insn")(visitor, term, measure, subdomain_id)
tag = get_global_context_value("kernel") knl = extract_kernel_from_cache("kernel_default")
knl = extract_kernel_from_cache(tag)
# All items with the kernel tags can be destroyed once a kernel has been generated # All items with the kernel tags can be destroyed once a kernel has been generated
from dune.perftool.generation import delete_cache_items from dune.perftool.generation import delete_cache_items
delete_cache_items(tag) delete_cache_items("kernel_default")
return knl return knl
...@@ -604,8 +603,8 @@ class LoopyKernelMethod(ClassMember): ...@@ -604,8 +603,8 @@ class LoopyKernelMethod(ClassMember):
if initializer_list: if initializer_list:
content[-1] = content[-1] + " :" content[-1] = content[-1] + " :"
for init in initializer_list[:-1]: for init in initializer_list[:-1]:
content.append(" "*4 + init + ",") content.append(" " * 4 + init + ",")
content.append(" "*4 + initializer_list[-1]) content.append(" " * 4 + initializer_list[-1])
content.append('{') content.append('{')
if kernel is not None: if kernel is not None:
......
...@@ -84,18 +84,6 @@ def quadrature_points_per_direction(): ...@@ -84,18 +84,6 @@ def quadrature_points_per_direction():
return nb_qp return nb_qp
@class_member(classtag="operator")
def define_number_of_quadrature_points_per_direction(name):
number_qp = quadrature_points_per_direction()
return "enum {{ {} = {} }};".format(name, str(number_qp))
def name_number_of_quadrature_points_per_direction():
name = "m"
define_number_of_quadrature_points_per_direction(name)
return name
def polynomial_degree(): def polynomial_degree():
form = get_global_context_value("formdata").preprocessed_form form = get_global_context_value("formdata").preprocessed_form
return form.coefficients()[0].ufl_element()._degree return form.coefficients()[0].ufl_element()._degree
...@@ -105,22 +93,10 @@ def basis_functions_per_direction(): ...@@ -105,22 +93,10 @@ def basis_functions_per_direction():
return polynomial_degree() + 1 return polynomial_degree() + 1
@class_member(classtag="operator")
def define_number_of_basis_functions_per_direction(name):
number_basis = basis_functions_per_direction()
return "enum {{ {} = {} }};".format(name, str(number_basis))
def name_number_of_basis_functions_per_direction():
name = "n"
define_number_of_basis_functions_per_direction(name)
return name
@class_member(classtag="operator") @class_member(classtag="operator")
def define_oned_quadrature_weights(name): def define_oned_quadrature_weights(name):
range_field = lop_template_range_field() range_field = lop_template_range_field()
number_qp = name_number_of_quadrature_points_per_direction() number_qp = quadrature_points_per_direction()
return "{} {}[{}];".format(range_field, name, number_qp) return "{} {}[{}];".format(range_field, name, number_qp)
...@@ -134,7 +110,7 @@ def name_oned_quadrature_weights(): ...@@ -134,7 +110,7 @@ def name_oned_quadrature_weights():
@class_member(classtag="operator") @class_member(classtag="operator")
def define_oned_quadrature_points(name): def define_oned_quadrature_points(name):
range_field = lop_template_range_field() range_field = lop_template_range_field()
number_qp = name_number_of_quadrature_points_per_direction() number_qp = quadrature_points_per_direction()
return "{} {}[{}];".format(range_field, name, number_qp) return "{} {}[{}];".format(range_field, name, number_qp)
...@@ -186,7 +162,7 @@ def name_polynomials(): ...@@ -186,7 +162,7 @@ def name_polynomials():
def sort_quadrature_points_weights(): def sort_quadrature_points_weights():
range_field = lop_template_range_field() range_field = lop_template_range_field()
domain_field = name_domain_field() domain_field = name_domain_field()
number_qp = name_number_of_quadrature_points_per_direction() number_qp = quadrature_points_per_direction()
qp = name_oned_quadrature_points() qp = name_oned_quadrature_points()
qw = name_oned_quadrature_weights() qw = name_oned_quadrature_weights()
include_file("dune/perftool/sumfact/onedquadrature.hh", filetag="operatorfile") include_file("dune/perftool/sumfact/onedquadrature.hh", filetag="operatorfile")
...@@ -196,7 +172,7 @@ def sort_quadrature_points_weights(): ...@@ -196,7 +172,7 @@ def sort_quadrature_points_weights():
@iname(kernel="operator") @iname(kernel="operator")
def theta_iname(name, bound): def theta_iname(name, bound):
name = "{}_{}".format(name, bound) name = "{}_{}".format(name, bound)
domain(name, bound) domain(name, bound, kernel="operator")
return name return name
...@@ -205,9 +181,9 @@ def construct_theta(name, transpose, derivative): ...@@ -205,9 +181,9 @@ def construct_theta(name, transpose, derivative):
sort_quadrature_points_weights() sort_quadrature_points_weights()
if transpose: if transpose:
shape = (name_number_of_basis_functions_per_direction(), name_number_of_quadrature_points_per_direction()) shape = (basis_functions_per_direction(), quadrature_points_per_direction())
else: else:
shape = (name_number_of_quadrature_points_per_direction(), name_number_of_basis_functions_per_direction()) shape = (quadrature_points_per_direction(), basis_functions_per_direction())
polynomials = name_polynomials() polynomials = name_polynomials()
qp = name_oned_quadrature_points() qp = name_oned_quadrature_points()
...@@ -217,7 +193,6 @@ def construct_theta(name, transpose, derivative): ...@@ -217,7 +193,6 @@ def construct_theta(name, transpose, derivative):
# access = "j,i" if transpose else "i,j" # access = "j,i" if transpose else "i,j"
basispol = "dp" if derivative else "p" basispol = "dp" if derivative else "p"
polynomial_access = "{},{}[{}]".format(i, qp, j) if transpose else "{},{}[{}]".format(j, qp, i) polynomial_access = "{},{}[{}]".format(i, qp, j) if transpose else "{},{}[{}]".format(j, qp, i)
return instruction(code="{}.colmajoraccess({},{}) = {}.{}({});".format(name, i, j, polynomials, basispol, polynomial_access), return instruction(code="{}.colmajoraccess({},{}) = {}.{}({});".format(name, i, j, polynomials, basispol, polynomial_access),
kernel="operator", kernel="operator",
within_inames=frozenset({i, j}), within_inames=frozenset({i, j}),
......
...@@ -7,7 +7,7 @@ from dune.perftool.generation import (backend, ...@@ -7,7 +7,7 @@ from dune.perftool.generation import (backend,
temporary_variable, temporary_variable,
) )
from dune.perftool.sumfact.amatrix import (name_number_of_quadrature_points_per_direction, from dune.perftool.sumfact.amatrix import (quadrature_points_per_direction,
name_oned_quadrature_points, name_oned_quadrature_points,
name_oned_quadrature_weights, name_oned_quadrature_weights,
) )
...@@ -71,7 +71,7 @@ def pymbolic_base_weight(): ...@@ -71,7 +71,7 @@ def pymbolic_base_weight():
@iname @iname
def sumfact_quad_iname(d, context): def sumfact_quad_iname(d, context):
name = "quad_{}_{}".format(context, d) name = "quad_{}_{}".format(context, d)
domain(name, name_number_of_quadrature_points_per_direction()) domain(name, quadrature_points_per_direction())
return name return name
......
...@@ -213,18 +213,16 @@ def test_no_caching_function(): ...@@ -213,18 +213,16 @@ def test_no_caching_function():
def test_multiple_kernels_1(): def test_multiple_kernels_1():
preamble = generator_factory(item_tags=("preamble",), context_tags=("kernel",)) preamble = generator_factory(item_tags=("preamble",), context_tags=("kernel",))
@preamble @preamble(kernel="k1")
def pre1(): def pre1():
return "blubb" return "blubb"
@preamble @preamble(kernel="k2")
def pre2(): def pre2():
return "bla" return "bla"
with global_context(kernel="k1"): pre1()
pre1() pre2()
with global_context(kernel="k2"):
pre2()
preambles = [p for p in retrieve_cache_items("preamble")] preambles = [p for p in retrieve_cache_items("preamble")]
assert len(preambles) == 2 assert len(preambles) == 2
...@@ -241,18 +239,16 @@ def test_multiple_kernels_1(): ...@@ -241,18 +239,16 @@ def test_multiple_kernels_1():
def test_multiple_kernels_2(): def test_multiple_kernels_2():
preamble = generator_factory(item_tags=("preamble",), context_tags=("kernel",)) preamble = generator_factory(item_tags=("preamble",), context_tags=("kernel",))
@preamble @preamble(kernel="k1")
def pre1(): def pre1():
return "blubb" return "blubb"
@preamble @preamble(kernel="k2")
def pre2(): def pre2():
pre1()
return "bla" return "bla"
with global_context(kernel="k1"): pre2()
with global_context(kernel="k2"):
pre2()
pre1()
preambles = [p for p in retrieve_cache_items("preamble")] preambles = [p for p in retrieve_cache_items("preamble")]
assert len(preambles) == 2 assert len(preambles) == 2
...@@ -267,52 +263,6 @@ def test_multiple_kernels_2(): ...@@ -267,52 +263,6 @@ def test_multiple_kernels_2():
def test_multiple_kernels_3(): def test_multiple_kernels_3():
preamble = generator_factory(item_tags=("preamble",), context_tags=("kernel",))
@preamble(kernel="k3")
def pre3():
return "foo"
@preamble(kernel="k4")
def pre4():
return "bar"
pre3()
pre4()
preambles = [p for p in retrieve_cache_items("preamble")]
assert len(preambles) == 2
k3, = retrieve_cache_items("k3")
assert k3 == "foo"
k4, = retrieve_cache_items("k4")
assert k4 == "bar"
delete_cache_items()
def test_multiple_kernels_4():
gen = generator_factory(item_tags=("tag",), context_tags=("kernel",), no_deco=True)
with global_context(kernel="k1"):
gen("foo")
with global_context(kernel="k2"):
gen("bar")
assert len([i for i in retrieve_cache_items("tag")]) == 2
k1, = retrieve_cache_items("k1")
assert k1 == "foo"
k2, = retrieve_cache_items("k2")
assert k2 == "bar"
delete_cache_items()
def test_multiple_kernels_5():
gen = generator_factory(item_tags=("tag",), context_tags=("kernel",), no_deco=True) gen = generator_factory(item_tags=("tag",), context_tags=("kernel",), no_deco=True)
gen("foo", kernel="k1") gen("foo", kernel="k1")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment