diff --git a/python/dune/perftool/compile.py b/python/dune/perftool/compile.py index 52f707ce70455bbb1dd3cc6e954aeaa13daed48d..8e537016b7acd691dca9f7425580fdf6f6798eb8 100644 --- a/python/dune/perftool/compile.py +++ b/python/dune/perftool/compile.py @@ -10,7 +10,10 @@ import loopy from ufl.algorithms import compute_form_data, read_ufl_file from ufl.algorithms.formfiles import interpret_ufl_namespace -from dune.perftool.generation import delete_cache_items, global_context +from dune.perftool.generation import (delete_cache_items, + global_context, + subst_rule, + ) from dune.perftool.interactive import start_interactive_session from dune.perftool.options import get_option from dune.perftool.pdelab.driver import generate_driver @@ -82,6 +85,9 @@ def read_ufl(uflfile): # Enrich data by some additional objects we deem worth keeping if get_option("exact_solution_expression"): data.object_by_name[get_option("exact_solution_expression")] = namespace[get_option("exact_solution_expression")] + for name, expr in namespace.items(): + if name.startswith("cse"): + data.object_by_name[name] = namespace[name] formdatas = [] forms = data.forms diff --git a/python/dune/perftool/generation/__init__.py b/python/dune/perftool/generation/__init__.py index 3e04364509a18fc62e7b6304cf9d100118b1d7b9..0af1aff7fca01d7f6dcd36227ea2f91626640269 100644 --- a/python/dune/perftool/generation/__init__.py +++ b/python/dune/perftool/generation/__init__.py @@ -43,6 +43,7 @@ from dune.perftool.generation.loopy import (barrier, kernel_cached, noop_instruction, silenced_warning, + subst_rule, temporary_variable, transform, valuearg, diff --git a/python/dune/perftool/generation/loopy.py b/python/dune/perftool/generation/loopy.py index 603f32f27d89fb03d2406840df559a16d37c0402..ac9f91bd716e2dc5936b053e546742e15707bce8 100644 --- a/python/dune/perftool/generation/loopy.py +++ b/python/dune/perftool/generation/loopy.py @@ -199,3 +199,13 @@ def loopy_class_member(name, classtag=None, potentially_vectorized=False, **kwar globalarg(name, **kwargs) return name + + +@generator_factory(item_tags=("substrule",), + context_tags="kernel", + cache_key_generator=lambda n, e, v: e, + ) +def subst_rule(name, expr, visitor): + if name is None: + raise ValueError("Requested a Substitution Rule that was not predefined!") + return lp.SubstitutionRule(name, (), visitor(expr)) diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py index 09a12e772d1c73803c36aca32d1447a934e4e353..3d39a24368659e3272c09f2a92b497258fa4f834 100644 --- a/python/dune/perftool/pdelab/localoperator.py +++ b/python/dune/perftool/pdelab/localoperator.py @@ -19,6 +19,7 @@ from dune.perftool.generation import (backend, post_include, retrieve_cache_functions, retrieve_cache_items, + subst_rule, template_parameter, ) from dune.perftool.cgen.clazz import (AccessModifier, @@ -412,6 +413,13 @@ def visit_integrals(integrals): interface = PDELabInterface() from dune.perftool.ufl.visitor import UFL2LoopyVisitor visitor = UFL2LoopyVisitor(interface, measure, indexmap) + + # Inspect the UFL file for manual common subexpression elimination stuff! + data = get_global_context_value("data") + for name, expr in data.object_by_name.items(): + if name.startswith("cse"): + subst_rule(name, expr, visitor) + get_backend(interface="accum_insn")(visitor, term, measure, subdomain_id) diff --git a/python/dune/perftool/ufl/visitor.py b/python/dune/perftool/ufl/visitor.py index f8d1f7d0ed94025a84f851395a8b20c0bd0988c4..76d169c9109c685e931ac738d2df6bbc72b871e3 100644 --- a/python/dune/perftool/ufl/visitor.py +++ b/python/dune/perftool/ufl/visitor.py @@ -5,6 +5,7 @@ to pymbolic and loopy. from dune.perftool.error import PerftoolUFLError from dune.perftool.generation import (domain, get_global_context_value, + subst_rule, ) from dune.perftool.ufl.flatoperators import get_operands from dune.perftool.ufl.modified_terminals import (ModifiedTerminalTracker, @@ -45,9 +46,6 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): # Call base class constructors super(UFL2LoopyVisitor, self).__init__() - # Allow recursion through self.call(..) - call = MultiFunction.__call__ - def __call__(self, o): # Reset some state variables that are reinitialized for each accumulation term self.indices = None @@ -56,6 +54,12 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): return self.call(o) + def call(self, o): + try: + return subst_rule(None, o, self) + except ValueError: + return MultiFunction.__call__(self, o) + # # Form argument/coefficients handlers: # This is where the actual domain specific work happens