diff --git a/python/dune/perftool/generation/__init__.py b/python/dune/perftool/generation/__init__.py index d2712471e4cff347b37a9a5f5452e2a34a4feb72..c668ed2c22fbd37dea9c1cabefb7ac03e1dd68bb 100644 --- a/python/dune/perftool/generation/__init__.py +++ b/python/dune/perftool/generation/__init__.py @@ -31,7 +31,7 @@ from dune.perftool.generation.cpp import (base_class, template_parameter, ) -from dune.perftool.generation.hooks import (register_hook, +from dune.perftool.generation.hooks import (hook, run_hook, ) diff --git a/python/dune/perftool/generation/hooks.py b/python/dune/perftool/generation/hooks.py index 4e2daa12f22a0833b21ef23dc104c103730bfcb0..a84d3a01a6e0592185ee15646af06315f45aa932 100644 --- a/python/dune/perftool/generation/hooks.py +++ b/python/dune/perftool/generation/hooks.py @@ -1,16 +1,53 @@ """ All the infrastructure code related to adding hooks to the code generation process """ +from dune.perftool.error import PerftoolError _hooks = {} -def register_hook(hookname, func): - current = _hooks.setdefault(hookname, ()) - current = list(current) - current.append(func) - _hooks[hookname] = tuple(current) +def hook(hookname): + """ A decorator for hook functions """ + def _hook(func): + current = _hooks.setdefault(hookname, ()) + current = list(current) + current.append(func) + _hooks[hookname] = tuple(current) -def run_hook(hookname, *args, **kwargs): - for hook in _hooks.get(hookname, ()): - hook(*args, **kwargs) + return func + + return _hook + + +class ReturnArg(object): + """ A wrapper for a hook argument, that will be replaced with + the return value of the previous hook functions. That allows + a chain of function calls like a loopy transformation sequence. + """ + def __init__(self, arg): + self.arg = arg + + +def run_hook(name=None, args=[], kwargs={}): + if name is None: + raise PerftoolError("Running hook requires the hook name!") + + # Handle occurences of ReturnArg in the given arguments + occ = list(isinstance(a, ReturnArg) for a in args) + assert occ.count(True) <= 1 + index = None + if occ.count(True): + index = occ.index(True) + args = list(args) + if index is not None: + args[index] = args[index].arg + + # Run the actual hooks + for hook in _hooks.get(name, ()): + ret = hook(*args, **kwargs) + + # And modify the args for chained hooks + if index is not None: + args[index] = ret + + return ret \ No newline at end of file diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py index f5ac0b8db07da66e21fe4c3dcc53c08610c2084b..ba16965b894617bdb4fac35f104b1be6eda31204 100644 --- a/python/dune/perftool/pdelab/localoperator.py +++ b/python/dune/perftool/pdelab/localoperator.py @@ -453,7 +453,9 @@ def visit_integral(integral): visitor = get_visitor(measure, subdomain_id) visitor.accumulate(integrand) - run_hook("after_visit", visitor) + run_hook(name="after_visit", + args=(visitor,), + ) def generate_kernel(integrals):