diff --git a/python/dune/perftool/generation/__init__.py b/python/dune/perftool/generation/__init__.py index e541e71366371039157cb757201f710ffd0a19ca..fc9a6ce42f58f2eecae533a6b7add45dbce450c2 100644 --- a/python/dune/perftool/generation/__init__.py +++ b/python/dune/perftool/generation/__init__.py @@ -31,6 +31,11 @@ from dune.perftool.generation.cpp import (base_class, template_parameter, ) +from dune.perftool.generation.hooks import (hook, + ReturnArg, + run_hook, + ) + from dune.perftool.generation.loopy import (barrier, constantarg, domain, diff --git a/python/dune/perftool/generation/hooks.py b/python/dune/perftool/generation/hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..31b35754ad44badfabea090e7c82b0ed55c1871d --- /dev/null +++ b/python/dune/perftool/generation/hooks.py @@ -0,0 +1,53 @@ +""" All the infrastructure code related to adding hooks to the code generation process """ + + +_hooks = {} + + +def hook(hookname): + """ A decorator for hook functions """ + + def _hook(func): + current = _hooks.setdefault(hookname, ()) + current = list(current) + current.append(func) + _hooks[hookname] = tuple(current) + + return func + + return _hook + + +class ReturnArg(object): + """ A wrapper for a hook argument, that will be replaced with + the return value of the previous hook functions. That allows + a chain of function calls like a loopy transformation sequence. + """ + def __init__(self, arg): + self.arg = arg + + +def run_hook(name=None, args=[], kwargs={}): + if name is None: + raise PerftoolError("Running hook requires the hook name!") + + # Handle occurences of ReturnArg in the given arguments + occ = list(isinstance(a, ReturnArg) for a in args) + assert occ.count(True) <= 1 + index = None + if occ.count(True): + index = occ.index(True) + args = list(args) + ret = None + if index is not None: + ret = args[index].arg + + # Run the actual hooks + for hook in _hooks.get(name, ()): + # Modify the args for chained hooks + if index is not None: + args[index] = ret + + ret = hook(*args, **kwargs) + + return ret diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py index 93b68e7e6f2bc75e8cb9b6346d280573d58d2200..9c7aab0ff537ac0641a9f4cb5807dcf608bf960c 100644 --- a/python/dune/perftool/pdelab/localoperator.py +++ b/python/dune/perftool/pdelab/localoperator.py @@ -28,6 +28,8 @@ from dune.perftool.generation import (backend, post_include, retrieve_cache_functions, retrieve_cache_items, + ReturnArg, + run_hook, template_parameter, ) from dune.perftool.cgen.clazz import (AccessModifier, @@ -465,6 +467,10 @@ def visit_integral(integral): visitor = get_visitor(measure, subdomain_id) visitor.accumulate(integrand) + run_hook(name="after_visit", + args=(visitor,), + ) + def generate_kernel(integrals): logger = logging.getLogger(__name__) @@ -500,6 +506,11 @@ def generate_kernel(integrals): # Clean the cache from any data collected after the dry run delete_cache_items("dryrundata") + # Run preprocessing from custom user code + knl = run_hook(name="loopy_kernel", + args=(ReturnArg(knl),), + ) + return knl diff --git a/python/dune/perftool/ufl/preprocess.py b/python/dune/perftool/ufl/preprocess.py index 19ca10359de05dec5154dc194d18a9233645664b..bf6b1af86e7424b79d7da6d1602ee7d0e965a0d6 100644 --- a/python/dune/perftool/ufl/preprocess.py +++ b/python/dune/perftool/ufl/preprocess.py @@ -1,5 +1,6 @@ """ Preprocessing algorithms for UFL forms """ +from dune.perftool.generation import run_hook, ReturnArg import ufl.classes as uc import ufl.algorithms.apply_function_pullbacks as afp @@ -39,6 +40,10 @@ def preprocess_form(form): formdata.preprocessed_form = apply_default_transformations(formdata.preprocessed_form) + # Run preprocessing from custom user code + formdata.preprocessed_form = run_hook(name="preprocess", + args=(ReturnArg(formdata.preprocessed_form),)) + return formdata