diff --git a/python/dune/codegen/blockstructured/argument.py b/python/dune/codegen/blockstructured/argument.py index afe120a8ec142820cf8ea3944538802b721f4469..b0940e669058777886da45938482f4cf747503e2 100644 --- a/python/dune/codegen/blockstructured/argument.py +++ b/python/dune/codegen/blockstructured/argument.py @@ -26,8 +26,8 @@ def name_alias(container, lfs, element): # TODO remove the need for element -@kernel_cached @backend(interface="pymbolic_coefficient", name="blockstructured") +@kernel_cached def pymbolic_coefficient(container, lfs, element, index): # TODO introduce a proper type for local function spaces! if isinstance(lfs, str): diff --git a/python/dune/codegen/generation/cache.py b/python/dune/codegen/generation/cache.py index bc5a1059acaa95ffd867fac3da2fa6d1273f644f..c1622cb1528f78fe4425334343484d963e52d333 100644 --- a/python/dune/codegen/generation/cache.py +++ b/python/dune/codegen/generation/cache.py @@ -87,10 +87,9 @@ class _RegisteredFunction(object): # Allow order independence of the backend and the generator decorators. # If backend was applied first, we resolve the issued FuncProxy object - from dune.codegen.generation.backend import FuncProxy, register_backend + from dune.codegen.generation.backend import FuncProxy if isinstance(self.func, FuncProxy): - register_backend(self.func.interface, self.func.name, self) - self.func = self.func.func + raise NotImplementedError("Please use @backend as the outer decorator if combining with generator decorators") def _get_content(self, key): return self._memoize_cache[key].value @@ -191,7 +190,12 @@ def generator_factory(**factory_kwargs): # if args: assert len(args) == 1 - funcobj = _generators.setdefault(args[0], _RegisteredFunction(args[0], **kwargs)) + key = args[0] + + if hasattr(key, "__name__") and key.__name__ == '<lambda>': + key = str(kwargs) + + funcobj = _generators.setdefault(key, _RegisteredFunction(args[0], **kwargs)) return lambda *a, **ka: funcobj(*a, **ka) else: def __dec(f): diff --git a/python/dune/codegen/loopy/transformations/vectorize_quad.py b/python/dune/codegen/loopy/transformations/vectorize_quad.py index 2ff6361f445ce44e73c659aa34bf7e4e7f02646b..77686ee2355fc4828322437aeb3edcbaeb66770b 100644 --- a/python/dune/codegen/loopy/transformations/vectorize_quad.py +++ b/python/dune/codegen/loopy/transformations/vectorize_quad.py @@ -324,8 +324,10 @@ def _vectorize_quadrature_loop(knl, inames, suffix): def vectorize_quadrature_loop(knl): # Loop over the quadrature loops that exist in the kernel. - # This is implemented a bit hacky right now... - for key, inames in quadrature_inames._memoize_cache.items(): + # This is implemented very hacky right now... + from dune.codegen.generation.cache import _generators as _g + gen = list(filter(lambda i: hasattr(i[0], "__name__") and i[0].__name__ == "quadrature_inames", _g.items()))[0][1] + for key, inames in gen._memoize_cache.items(): element = key[0][0] if element is None: suffix = '' diff --git a/python/dune/codegen/pdelab/argument.py b/python/dune/codegen/pdelab/argument.py index ce16cd1121458ab27325312a727fd74c98d708cc..8972c32c5b7dd38cbad60054737503f39ce93c95 100644 --- a/python/dune/codegen/pdelab/argument.py +++ b/python/dune/codegen/pdelab/argument.py @@ -131,8 +131,8 @@ def name_applycontainer(restriction): return name -@kernel_cached @backend(interface="pymbolic_coefficient") +@kernel_cached def pymbolic_coefficient(container, lfs, index): # TODO introduce a proper type for local function spaces! if isinstance(lfs, str): diff --git a/python/test/dune/codegen/generation/test_backend.py b/python/test/dune/codegen/generation/test_backend.py index 76bda979394f0293f5b9d1a9314a80ce58ba8a66..e7b11e56dc14a18d18bf04ff453b3dc078cc702b 100644 --- a/python/test/dune/codegen/generation/test_backend.py +++ b/python/test/dune/codegen/generation/test_backend.py @@ -10,26 +10,12 @@ def f1(): return 1 -@generator_factory() -@backend(interface="foo", name="f2") -def f2(): - return 2 - - @backend(interface="bar", name="f3") @generator_factory() def f3(): return 3 -@generator_factory() -@backend(interface="bar", name="f4") -def f4(): - return 4 - - def test_backend(): assert get_backend(interface="foo", selector=lambda: "f1")() == 1 - assert get_backend(interface="foo", selector=lambda: "f2")() == 2 assert get_backend(interface="bar", selector=lambda: "f3")() == 3 - assert get_backend(interface="bar", selector=lambda: "f4")() == 4