diff --git a/python/dune/perftool/generation/hooks.py b/python/dune/perftool/generation/hooks.py index ed785ff70cd39061a822c9329831ffca709a62f7..31b35754ad44badfabea090e7c82b0ed55c1871d 100644 --- a/python/dune/perftool/generation/hooks.py +++ b/python/dune/perftool/generation/hooks.py @@ -38,17 +38,16 @@ def run_hook(name=None, args=[], kwargs={}): if occ.count(True): index = occ.index(True) args = list(args) + ret = None if index is not None: - args[index] = args[index].arg + ret = args[index].arg # Run the actual hooks - ret = None - for hook in _hooks.get(name, ()): - ret = hook(*args, **kwargs) - - # And modify the args for chained hooks + # Modify the args for chained hooks if index is not None: args[index] = ret + ret = hook(*args, **kwargs) + return ret diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py index ba16965b894617bdb4fac35f104b1be6eda31204..a41aeef32599edfb787c10f9b6177dbfd2ecf1bd 100644 --- a/python/dune/perftool/pdelab/localoperator.py +++ b/python/dune/perftool/pdelab/localoperator.py @@ -28,6 +28,7 @@ from dune.perftool.generation import (backend, post_include, retrieve_cache_functions, retrieve_cache_items, + ReturnArg, run_hook, template_parameter, ) @@ -492,6 +493,11 @@ def generate_kernel(integrals): # Clean the cache from any data collected after the dry run delete_cache_items("dryrundata") + # Run preprocessing from custom user code + knl = run_hook(name="loopy_kernel", + args=(ReturnArg(knl),), + ) + return knl diff --git a/python/dune/perftool/ufl/preprocess.py b/python/dune/perftool/ufl/preprocess.py index 19ca10359de05dec5154dc194d18a9233645664b..bf6b1af86e7424b79d7da6d1602ee7d0e965a0d6 100644 --- a/python/dune/perftool/ufl/preprocess.py +++ b/python/dune/perftool/ufl/preprocess.py @@ -1,5 +1,6 @@ """ Preprocessing algorithms for UFL forms """ +from dune.perftool.generation import run_hook, ReturnArg import ufl.classes as uc import ufl.algorithms.apply_function_pullbacks as afp @@ -39,6 +40,10 @@ def preprocess_form(form): formdata.preprocessed_form = apply_default_transformations(formdata.preprocessed_form) + # Run preprocessing from custom user code + formdata.preprocessed_form = run_hook(name="preprocess", + args=(ReturnArg(formdata.preprocessed_form),)) + return formdata