Skip to content
Snippets Groups Projects
Commit 4b6609e6 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Avoid having generation magic use global context

That was a bad idea from the beginning! It lead to a situation,
where calls from within a generator inherited their context tags.
Should be fixed now.
parent 9cefb0c1
No related branches found
No related tags found
No related merge requests found
......@@ -85,21 +85,31 @@ class _RegisteredFunction(object):
else:
return self._memoize_cache[key]
def call(self, *args, **kwargs):
# Get the cache key from the given arguments
cache_key = self.cache_key_generator(*args, **kwargs)
def __call__(self, *args, **kwargs):
# Modify the kwargs to include any context tags kept with the generator
for tag in self.context_tags:
if tag in self.kwargs and tag not in kwargs:
kwargs[tag] = self.kwargs[tag]
# Make sure that all keyword arguments have vanished from the cache_args
assert (lambda *a, **k: len(k) == 0)(cache_key)
# Keep an additional dictionary without context tags
without_context = {k: v for k, v in kwargs.items() if k not in self.context_tags}
# Add any context tags to the cache key
context_key = tuple(get_global_context_value(t, None) for t in self.context_tags)
# Get the cache key from the given arguments
context_key = tuple(kwargs.get(t, t + "_default") for t in self.context_tags)
cache_key = self.cache_key_generator(*args, **without_context)
cache_key = (cache_key, context_key)
# check whether we have a cache hit
if cache_key not in self._memoize_cache:
# evaluate the original function
val = self.on_store(self.func(*args, **kwargs))
# evaluate the original function: Once with context tags, once without.
# Reason: Some generators use their context tag to pass it on to other
# generators. That should be possible. However, those that do not do this
# get an unknown keyword...
try:
val = self.on_store(self.func(*args, **kwargs))
except:
val = self.on_store(self.func(*args, **without_context))
# Maybe wrap it with a counter!
if self.counted:
val = (get_counter('__cache_counted'), val)
......@@ -109,14 +119,6 @@ class _RegisteredFunction(object):
# Return the result for immediate usage
return self._get_content(cache_key)
def __call__(self, *args, **kwargs):
additional_kw = {k: kwargs[k] for k in kwargs if k in self.context_tags}
for k, v in self.kwargs.items():
additional_kw[k] = v
kwargs = {k: kwargs[k] for k in kwargs if k not in self.context_tags}
with global_context(**additional_kw):
return self.call(*args, **kwargs)
def generator_factory(**factory_kwargs):
""" A function decorator factory
......@@ -151,8 +153,7 @@ def generator_factory(**factory_kwargs):
context_tags: tuple, str
A single tag or tuple thereof, that will be added to the cache key. This
feature can be used to maintain multiple sets of memoized function evaluations,
for example if you generate multiple loopy kernels at the same time. The
given strings are used to look up in the global context manager for a tag.
for example if you generate multiple loopy kernels at the same time.
"""
# Tuplize the item_tags parameter
if "item_tags" in factory_kwargs and isinstance(factory_kwargs["item_tags"], str):
......
......@@ -23,14 +23,14 @@ def include_file(include, system=False):
@generator_factory(item_tags=("clazz", "initializer"), counted=True, context_tags=("classtag",), cache_key_generator=lambda o, p: o)
def initializer_list(obj, params, classtag=None):
def initializer_list(obj, params):
return "{}({})".format(obj, ", ".join(params))
@generator_factory(item_tags=("clazz", "baseclass"), context_tags=("classtag",), counted=True)
def base_class(baseclass, access=AccessModifier.PUBLIC, construction=[], **kwargs):
def base_class(baseclass, access=AccessModifier.PUBLIC, construction=[], classtag=None):
if construction:
initializer_list(baseclass, construction, **kwargs)
initializer_list(baseclass, construction, classtag=classtag)
return BaseClass(baseclass, inheritance=access)
......
......@@ -475,12 +475,11 @@ def generate_kernel(integrals):
visitor = UFL2LoopyVisitor(interface, measure, indexmap)
get_backend(interface="accum_insn")(visitor, term, measure, subdomain_id)
tag = get_global_context_value("kernel")
knl = extract_kernel_from_cache(tag)
knl = extract_kernel_from_cache("kernel_default")
# All items with the kernel tags can be destroyed once a kernel has been generated
from dune.perftool.generation import delete_cache_items
delete_cache_items(tag)
delete_cache_items("kernel_default")
return knl
......@@ -604,8 +603,8 @@ class LoopyKernelMethod(ClassMember):
if initializer_list:
content[-1] = content[-1] + " :"
for init in initializer_list[:-1]:
content.append(" "*4 + init + ",")
content.append(" "*4 + initializer_list[-1])
content.append(" " * 4 + init + ",")
content.append(" " * 4 + initializer_list[-1])
content.append('{')
if kernel is not None:
......
......@@ -84,18 +84,6 @@ def quadrature_points_per_direction():
return nb_qp
@class_member(classtag="operator")
def define_number_of_quadrature_points_per_direction(name):
number_qp = quadrature_points_per_direction()
return "enum {{ {} = {} }};".format(name, str(number_qp))
def name_number_of_quadrature_points_per_direction():
name = "m"
define_number_of_quadrature_points_per_direction(name)
return name
def polynomial_degree():
form = get_global_context_value("formdata").preprocessed_form
return form.coefficients()[0].ufl_element()._degree
......@@ -105,22 +93,10 @@ def basis_functions_per_direction():
return polynomial_degree() + 1
@class_member(classtag="operator")
def define_number_of_basis_functions_per_direction(name):
number_basis = basis_functions_per_direction()
return "enum {{ {} = {} }};".format(name, str(number_basis))
def name_number_of_basis_functions_per_direction():
name = "n"
define_number_of_basis_functions_per_direction(name)
return name
@class_member(classtag="operator")
def define_oned_quadrature_weights(name):
range_field = lop_template_range_field()
number_qp = name_number_of_quadrature_points_per_direction()
number_qp = quadrature_points_per_direction()
return "{} {}[{}];".format(range_field, name, number_qp)
......@@ -134,7 +110,7 @@ def name_oned_quadrature_weights():
@class_member(classtag="operator")
def define_oned_quadrature_points(name):
range_field = lop_template_range_field()
number_qp = name_number_of_quadrature_points_per_direction()
number_qp = quadrature_points_per_direction()
return "{} {}[{}];".format(range_field, name, number_qp)
......@@ -186,7 +162,7 @@ def name_polynomials():
def sort_quadrature_points_weights():
range_field = lop_template_range_field()
domain_field = name_domain_field()
number_qp = name_number_of_quadrature_points_per_direction()
number_qp = quadrature_points_per_direction()
qp = name_oned_quadrature_points()
qw = name_oned_quadrature_weights()
include_file("dune/perftool/sumfact/onedquadrature.hh", filetag="operatorfile")
......@@ -196,7 +172,7 @@ def sort_quadrature_points_weights():
@iname(kernel="operator")
def theta_iname(name, bound):
name = "{}_{}".format(name, bound)
domain(name, bound)
domain(name, bound, kernel="operator")
return name
......@@ -205,9 +181,9 @@ def construct_theta(name, transpose, derivative):
sort_quadrature_points_weights()
if transpose:
shape = (name_number_of_basis_functions_per_direction(), name_number_of_quadrature_points_per_direction())
shape = (basis_functions_per_direction(), quadrature_points_per_direction())
else:
shape = (name_number_of_quadrature_points_per_direction(), name_number_of_basis_functions_per_direction())
shape = (quadrature_points_per_direction(), basis_functions_per_direction())
polynomials = name_polynomials()
qp = name_oned_quadrature_points()
......@@ -217,7 +193,6 @@ def construct_theta(name, transpose, derivative):
# access = "j,i" if transpose else "i,j"
basispol = "dp" if derivative else "p"
polynomial_access = "{},{}[{}]".format(i, qp, j) if transpose else "{},{}[{}]".format(j, qp, i)
return instruction(code="{}.colmajoraccess({},{}) = {}.{}({});".format(name, i, j, polynomials, basispol, polynomial_access),
kernel="operator",
within_inames=frozenset({i, j}),
......
......@@ -7,7 +7,7 @@ from dune.perftool.generation import (backend,
temporary_variable,
)
from dune.perftool.sumfact.amatrix import (name_number_of_quadrature_points_per_direction,
from dune.perftool.sumfact.amatrix import (quadrature_points_per_direction,
name_oned_quadrature_points,
name_oned_quadrature_weights,
)
......@@ -71,7 +71,7 @@ def pymbolic_base_weight():
@iname
def sumfact_quad_iname(d, context):
name = "quad_{}_{}".format(context, d)
domain(name, name_number_of_quadrature_points_per_direction())
domain(name, quadrature_points_per_direction())
return name
......
......@@ -213,18 +213,16 @@ def test_no_caching_function():
def test_multiple_kernels_1():
preamble = generator_factory(item_tags=("preamble",), context_tags=("kernel",))
@preamble
@preamble(kernel="k1")
def pre1():
return "blubb"
@preamble
@preamble(kernel="k2")
def pre2():
return "bla"
with global_context(kernel="k1"):
pre1()
with global_context(kernel="k2"):
pre2()
pre1()
pre2()
preambles = [p for p in retrieve_cache_items("preamble")]
assert len(preambles) == 2
......@@ -241,18 +239,16 @@ def test_multiple_kernels_1():
def test_multiple_kernels_2():
preamble = generator_factory(item_tags=("preamble",), context_tags=("kernel",))
@preamble
@preamble(kernel="k1")
def pre1():
return "blubb"
@preamble
@preamble(kernel="k2")
def pre2():
pre1()
return "bla"
with global_context(kernel="k1"):
with global_context(kernel="k2"):
pre2()
pre1()
pre2()
preambles = [p for p in retrieve_cache_items("preamble")]
assert len(preambles) == 2
......@@ -267,52 +263,6 @@ def test_multiple_kernels_2():
def test_multiple_kernels_3():
preamble = generator_factory(item_tags=("preamble",), context_tags=("kernel",))
@preamble(kernel="k3")
def pre3():
return "foo"
@preamble(kernel="k4")
def pre4():
return "bar"
pre3()
pre4()
preambles = [p for p in retrieve_cache_items("preamble")]
assert len(preambles) == 2
k3, = retrieve_cache_items("k3")
assert k3 == "foo"
k4, = retrieve_cache_items("k4")
assert k4 == "bar"
delete_cache_items()
def test_multiple_kernels_4():
gen = generator_factory(item_tags=("tag",), context_tags=("kernel",), no_deco=True)
with global_context(kernel="k1"):
gen("foo")
with global_context(kernel="k2"):
gen("bar")
assert len([i for i in retrieve_cache_items("tag")]) == 2
k1, = retrieve_cache_items("k1")
assert k1 == "foo"
k2, = retrieve_cache_items("k2")
assert k2 == "bar"
delete_cache_items()
def test_multiple_kernels_5():
gen = generator_factory(item_tags=("tag",), context_tags=("kernel",), no_deco=True)
gen("foo", kernel="k1")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment