From e72b58306855449431508a52570fafb6c0b2c476 Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Tue, 23 Oct 2018 13:08:18 +0200 Subject: [PATCH] Make hooks decorators and allow chaining them --- python/dune/perftool/generation/__init__.py | 2 +- python/dune/perftool/generation/hooks.py | 53 +++++++++++++++++--- python/dune/perftool/pdelab/localoperator.py | 4 +- 3 files changed, 49 insertions(+), 10 deletions(-) diff --git a/python/dune/perftool/generation/__init__.py b/python/dune/perftool/generation/__init__.py index d2712471..c668ed2c 100644 --- a/python/dune/perftool/generation/__init__.py +++ b/python/dune/perftool/generation/__init__.py @@ -31,7 +31,7 @@ from dune.perftool.generation.cpp import (base_class, template_parameter, ) -from dune.perftool.generation.hooks import (register_hook, +from dune.perftool.generation.hooks import (hook, run_hook, ) diff --git a/python/dune/perftool/generation/hooks.py b/python/dune/perftool/generation/hooks.py index 4e2daa12..a84d3a01 100644 --- a/python/dune/perftool/generation/hooks.py +++ b/python/dune/perftool/generation/hooks.py @@ -1,16 +1,53 @@ """ All the infrastructure code related to adding hooks to the code generation process """ +from dune.perftool.error import PerftoolError _hooks = {} -def register_hook(hookname, func): - current = _hooks.setdefault(hookname, ()) - current = list(current) - current.append(func) - _hooks[hookname] = tuple(current) +def hook(hookname): + """ A decorator for hook functions """ + def _hook(func): + current = _hooks.setdefault(hookname, ()) + current = list(current) + current.append(func) + _hooks[hookname] = tuple(current) -def run_hook(hookname, *args, **kwargs): - for hook in _hooks.get(hookname, ()): - hook(*args, **kwargs) + return func + + return _hook + + +class ReturnArg(object): + """ A wrapper for a hook argument, that will be replaced with + the return value of the previous hook functions. That allows + a chain of function calls like a loopy transformation sequence. + """ + def __init__(self, arg): + self.arg = arg + + +def run_hook(name=None, args=[], kwargs={}): + if name is None: + raise PerftoolError("Running hook requires the hook name!") + + # Handle occurences of ReturnArg in the given arguments + occ = list(isinstance(a, ReturnArg) for a in args) + assert occ.count(True) <= 1 + index = None + if occ.count(True): + index = occ.index(True) + args = list(args) + if index is not None: + args[index] = args[index].arg + + # Run the actual hooks + for hook in _hooks.get(name, ()): + ret = hook(*args, **kwargs) + + # And modify the args for chained hooks + if index is not None: + args[index] = ret + + return ret \ No newline at end of file diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py index f5ac0b8d..ba16965b 100644 --- a/python/dune/perftool/pdelab/localoperator.py +++ b/python/dune/perftool/pdelab/localoperator.py @@ -453,7 +453,9 @@ def visit_integral(integral): visitor = get_visitor(measure, subdomain_id) visitor.accumulate(integrand) - run_hook("after_visit", visitor) + run_hook(name="after_visit", + args=(visitor,), + ) def generate_kernel(integrals): -- GitLab