Skip to content
Snippets Groups Projects
Commit 9772dbff authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Fix interoperability of classmethod generators with backend switches

parent 12b3b1cb
No related branches found
No related tags found
No related merge requests found
......@@ -26,8 +26,8 @@ def name_alias(container, lfs, element):
# TODO remove the need for element
@kernel_cached
@backend(interface="pymbolic_coefficient", name="blockstructured")
@kernel_cached
def pymbolic_coefficient(container, lfs, element, index):
# TODO introduce a proper type for local function spaces!
if isinstance(lfs, str):
......
......@@ -87,10 +87,9 @@ class _RegisteredFunction(object):
# Allow order independence of the backend and the generator decorators.
# If backend was applied first, we resolve the issued FuncProxy object
from dune.codegen.generation.backend import FuncProxy, register_backend
from dune.codegen.generation.backend import FuncProxy
if isinstance(self.func, FuncProxy):
register_backend(self.func.interface, self.func.name, self)
self.func = self.func.func
raise NotImplementedError("Please use @backend as the outer decorator if combining with generator decorators")
def _get_content(self, key):
return self._memoize_cache[key].value
......@@ -191,7 +190,12 @@ def generator_factory(**factory_kwargs):
#
if args:
assert len(args) == 1
funcobj = _generators.setdefault(args[0], _RegisteredFunction(args[0], **kwargs))
key = args[0]
if hasattr(key, "__name__") and key.__name__ == '<lambda>':
key = str(kwargs)
funcobj = _generators.setdefault(key, _RegisteredFunction(args[0], **kwargs))
return lambda *a, **ka: funcobj(*a, **ka)
else:
def __dec(f):
......
......@@ -324,8 +324,10 @@ def _vectorize_quadrature_loop(knl, inames, suffix):
def vectorize_quadrature_loop(knl):
# Loop over the quadrature loops that exist in the kernel.
# This is implemented a bit hacky right now...
for key, inames in quadrature_inames._memoize_cache.items():
# This is implemented very hacky right now...
from dune.codegen.generation.cache import _generators as _g
gen = list(filter(lambda i: hasattr(i[0], "__name__") and i[0].__name__ == "quadrature_inames", _g.items()))[0][1]
for key, inames in gen._memoize_cache.items():
element = key[0][0]
if element is None:
suffix = ''
......
......@@ -131,8 +131,8 @@ def name_applycontainer(restriction):
return name
@kernel_cached
@backend(interface="pymbolic_coefficient")
@kernel_cached
def pymbolic_coefficient(container, lfs, index):
# TODO introduce a proper type for local function spaces!
if isinstance(lfs, str):
......
......@@ -10,26 +10,12 @@ def f1():
return 1
@generator_factory()
@backend(interface="foo", name="f2")
def f2():
return 2
@backend(interface="bar", name="f3")
@generator_factory()
def f3():
return 3
@generator_factory()
@backend(interface="bar", name="f4")
def f4():
return 4
def test_backend():
assert get_backend(interface="foo", selector=lambda: "f1")() == 1
assert get_backend(interface="foo", selector=lambda: "f2")() == 2
assert get_backend(interface="bar", selector=lambda: "f3")() == 3
assert get_backend(interface="bar", selector=lambda: "f4")() == 4
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment