From e72b58306855449431508a52570fafb6c0b2c476 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Tue, 23 Oct 2018 13:08:18 +0200
Subject: [PATCH] Make hooks decorators and allow chaining them

---
 python/dune/perftool/generation/__init__.py  |  2 +-
 python/dune/perftool/generation/hooks.py     | 53 +++++++++++++++++---
 python/dune/perftool/pdelab/localoperator.py |  4 +-
 3 files changed, 49 insertions(+), 10 deletions(-)

diff --git a/python/dune/perftool/generation/__init__.py b/python/dune/perftool/generation/__init__.py
index d2712471..c668ed2c 100644
--- a/python/dune/perftool/generation/__init__.py
+++ b/python/dune/perftool/generation/__init__.py
@@ -31,7 +31,7 @@ from dune.perftool.generation.cpp import (base_class,
                                           template_parameter,
                                           )
 
-from dune.perftool.generation.hooks import (register_hook,
+from dune.perftool.generation.hooks import (hook,
                                             run_hook,
                                             )
 
diff --git a/python/dune/perftool/generation/hooks.py b/python/dune/perftool/generation/hooks.py
index 4e2daa12..a84d3a01 100644
--- a/python/dune/perftool/generation/hooks.py
+++ b/python/dune/perftool/generation/hooks.py
@@ -1,16 +1,53 @@
 """ All the infrastructure code related to adding hooks to the code generation process """
 
+from dune.perftool.error import PerftoolError
 
 _hooks = {}
 
 
-def register_hook(hookname, func):
-    current = _hooks.setdefault(hookname, ())
-    current = list(current)
-    current.append(func)
-    _hooks[hookname] = tuple(current)
+def hook(hookname):
+    """ A decorator for hook functions """
 
+    def _hook(func):
+        current = _hooks.setdefault(hookname, ())
+        current = list(current)
+        current.append(func)
+        _hooks[hookname] = tuple(current)
 
-def run_hook(hookname, *args, **kwargs):
-    for hook in _hooks.get(hookname, ()):
-        hook(*args, **kwargs)
+        return func
+
+    return _hook
+
+
+class ReturnArg(object):
+    """ A wrapper for a hook argument, that will be replaced with
+    the return value of the previous hook functions. That allows
+    a chain of function calls like a loopy transformation sequence.
+    """
+    def __init__(self, arg):
+        self.arg = arg
+
+
+def run_hook(name=None, args=[], kwargs={}):
+    if name is None:
+        raise PerftoolError("Running hook requires the hook name!")
+
+    # Handle occurences of ReturnArg in the given arguments
+    occ = list(isinstance(a, ReturnArg) for a in args)
+    assert occ.count(True) <= 1
+    index = None
+    if occ.count(True):
+        index = occ.index(True)
+    args = list(args)
+    if index is not None:
+        args[index] = args[index].arg
+
+    # Run the actual hooks
+    for hook in _hooks.get(name, ()):
+        ret = hook(*args, **kwargs)
+
+        # And modify the args for chained hooks
+        if index is not None:
+            args[index] = ret
+
+    return ret
\ No newline at end of file
diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py
index f5ac0b8d..ba16965b 100644
--- a/python/dune/perftool/pdelab/localoperator.py
+++ b/python/dune/perftool/pdelab/localoperator.py
@@ -453,7 +453,9 @@ def visit_integral(integral):
     visitor = get_visitor(measure, subdomain_id)
     visitor.accumulate(integrand)
 
-    run_hook("after_visit", visitor)
+    run_hook(name="after_visit",
+             args=(visitor,),
+             )
 
 
 def generate_kernel(integrals):
-- 
GitLab