From 12b3b1cbfc7a7f25c00d53d8edfb6a98f6b81885 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Wed, 28 Nov 2018 11:58:50 +0100
Subject: [PATCH] Allow generators to be used as class methods!

---
 python/dune/codegen/generation/backend.py |  6 ++---
 python/dune/codegen/generation/cache.py   | 33 ++++++++++++++++-------
 2 files changed, 26 insertions(+), 13 deletions(-)

diff --git a/python/dune/codegen/generation/backend.py b/python/dune/codegen/generation/backend.py
index a45f7ac8..ba793674 100644
--- a/python/dune/codegen/generation/backend.py
+++ b/python/dune/codegen/generation/backend.py
@@ -1,14 +1,14 @@
 from dune.codegen.generation.cache import _RegisteredFunction
 from dune.codegen.options import option_switch
-from pytools import Record
+from pytools import ImmutableRecord
 
 
 _backend_mapping = {}
 
 
-class FuncProxy(Record):
+class FuncProxy(ImmutableRecord):
     def __init__(self, interface, name, func):
-        Record.__init__(self, interface=interface, name=name, func=func)
+        ImmutableRecord.__init__(self, interface=interface, name=name, func=func)
 
     def __call__(self, *args, **kwargs):
         return self.func(*args, **kwargs)
diff --git a/python/dune/codegen/generation/cache.py b/python/dune/codegen/generation/cache.py
index 223ab1bf..bc5a1059 100644
--- a/python/dune/codegen/generation/cache.py
+++ b/python/dune/codegen/generation/cache.py
@@ -10,7 +10,7 @@ from dune.codegen.generation.counter import get_counter
 from dune.codegen.options import get_option
 
 # Store a global list of generator functions
-_generators = []
+_generators = {}
 
 
 def _freeze(data):
@@ -85,9 +85,6 @@ class _RegisteredFunction(object):
         # Initialize the memoization cache
         self._memoize_cache = {}
 
-        # Register this generator function
-        _generators.append(self)
-
         # Allow order independence of the backend and the generator decorators.
         # If backend was applied first, we resolve the issued FuncProxy object
         from dune.codegen.generation.backend import FuncProxy, register_backend
@@ -180,11 +177,27 @@ def generator_factory(**factory_kwargs):
         # Modify the kwargs according to the factorys kwargs
         for k in factory_kwargs:
             kwargs[k] = factory_kwargs[k]
+        # If there args, this function is used as a decorator, as in this example
+        #
+        # @decorator
+        # def foo():
+        #     pass
+        #
+        # If there are no args, this is used as a decorator factory:
+        #
+        # @decorator(bar=42)
+        # def foo():
+        #     pass
+        #
         if args:
             assert len(args) == 1
-            return _RegisteredFunction(args[0], **kwargs)
+            funcobj = _generators.setdefault(args[0], _RegisteredFunction(args[0], **kwargs))
+            return lambda *a, **ka: funcobj(*a, **ka)
         else:
-            return lambda f: _RegisteredFunction(f, **kwargs)
+            def __dec(f):
+                funcobj = _generators.setdefault(f, _RegisteredFunction(f, **kwargs))
+                return lambda *a, **ka: funcobj(*a, **ka)
+            return __dec
 
     if no_deco:
         return _dec(lambda x: x)
@@ -241,13 +254,13 @@ def retrieve_cache_items(condition=True, make_generable=False):
             return content
 
     # First yield all those items that are not sorted
-    for gen in filter(lambda g: not g.counted, _generators):
+    for gen in filter(lambda g: not g.counted, _generators.values()):
         for item in _filter_cache_items(gen, condition).values():
             yield as_generable(item.value)
 
     # And now the sorted ones
     counted_ones = []
-    for gen in filter(lambda g: g.counted, _generators):
+    for gen in filter(lambda g: g.counted, _generators.values()):
         counted_ones.extend(_filter_cache_items(gen, condition).values())
 
     for item in sorted(counted_ones, key=lambda i: i.count):
@@ -264,12 +277,12 @@ def delete_cache_items(condition=True, keep=False):
     if not keep:
         condition = "not ({})".format(condition)
 
-    for gen in _generators:
+    for gen in _generators.values():
         gen._memoize_cache = _filter_cache_items(gen, condition)
 
 
 def retrieve_cache_functions(condition="True"):
-    return [g.func for g in _generators if eval(condition, _ConditionDict(g.item_tags))]
+    return [g.func for g in _generators.values() if eval(condition, _ConditionDict(g.item_tags))]
 
 
 def inspect_generator(gen):
-- 
GitLab