Skip to content
Snippets Groups Projects
Commit 2ff0ad9e authored by Dominic Kempf's avatar Dominic Kempf
Browse files

[!277] Add a first implementation of hooks

Merge branch 'feature/code-generation-hooks' into 'master'

ref:dominic/dune-perftool This is the first minimal implementation of how code
generation hooks from downstream projects could look like.

There is a few more things to think about (feel invited to share ideas):

-   \[x\] How to document the arguments and return values expected from hooks
-   \[x\] How to handle multiple hooks registered to the same hook point and
    return values (this is quite relevant once you want to do loopy
    transformations in a hook. It means that you want to "chain" the hooks)

This fixes [#129].

See merge request [dominic/dune-perftool!277]

  [#129]: gitlab.dune-project.org/NoneNone/issues/129
  [dominic/dune-perftool!277]: gitlab.dune-project.org/dominic/dune-perftool/merge_requests/277


Closes #129
parents 6b8463dc 55ad773d
No related branches found
No related tags found
No related merge requests found
......@@ -31,6 +31,11 @@ from dune.perftool.generation.cpp import (base_class,
template_parameter,
)
from dune.perftool.generation.hooks import (hook,
ReturnArg,
run_hook,
)
from dune.perftool.generation.loopy import (barrier,
constantarg,
domain,
......
""" All the infrastructure code related to adding hooks to the code generation process """
_hooks = {}
def hook(hookname):
""" A decorator for hook functions """
def _hook(func):
current = _hooks.setdefault(hookname, ())
current = list(current)
current.append(func)
_hooks[hookname] = tuple(current)
return func
return _hook
class ReturnArg(object):
""" A wrapper for a hook argument, that will be replaced with
the return value of the previous hook functions. That allows
a chain of function calls like a loopy transformation sequence.
"""
def __init__(self, arg):
self.arg = arg
def run_hook(name=None, args=[], kwargs={}):
if name is None:
raise PerftoolError("Running hook requires the hook name!")
# Handle occurences of ReturnArg in the given arguments
occ = list(isinstance(a, ReturnArg) for a in args)
assert occ.count(True) <= 1
index = None
if occ.count(True):
index = occ.index(True)
args = list(args)
ret = None
if index is not None:
ret = args[index].arg
# Run the actual hooks
for hook in _hooks.get(name, ()):
# Modify the args for chained hooks
if index is not None:
args[index] = ret
ret = hook(*args, **kwargs)
return ret
......@@ -28,6 +28,8 @@ from dune.perftool.generation import (backend,
post_include,
retrieve_cache_functions,
retrieve_cache_items,
ReturnArg,
run_hook,
template_parameter,
)
from dune.perftool.cgen.clazz import (AccessModifier,
......@@ -465,6 +467,10 @@ def visit_integral(integral):
visitor = get_visitor(measure, subdomain_id)
visitor.accumulate(integrand)
run_hook(name="after_visit",
args=(visitor,),
)
def generate_kernel(integrals):
logger = logging.getLogger(__name__)
......@@ -500,6 +506,11 @@ def generate_kernel(integrals):
# Clean the cache from any data collected after the dry run
delete_cache_items("dryrundata")
# Run preprocessing from custom user code
knl = run_hook(name="loopy_kernel",
args=(ReturnArg(knl),),
)
return knl
......
""" Preprocessing algorithms for UFL forms """
from dune.perftool.generation import run_hook, ReturnArg
import ufl.classes as uc
import ufl.algorithms.apply_function_pullbacks as afp
......@@ -39,6 +40,10 @@ def preprocess_form(form):
formdata.preprocessed_form = apply_default_transformations(formdata.preprocessed_form)
# Run preprocessing from custom user code
formdata.preprocessed_form = run_hook(name="preprocess",
args=(ReturnArg(formdata.preprocessed_form),))
return formdata
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment