From 3923568efbe35db58998ebb24cf7eef171a1bbc8 Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Mon, 23 Jan 2017 15:53:04 +0100 Subject: [PATCH] Introduce a general infrastructure to detect common subexpressions --- python/dune/perftool/compile.py | 8 +++++++- python/dune/perftool/generation/__init__.py | 1 + python/dune/perftool/generation/loopy.py | 10 ++++++++++ python/dune/perftool/pdelab/localoperator.py | 8 ++++++++ python/dune/perftool/ufl/visitor.py | 10 +++++++--- 5 files changed, 33 insertions(+), 4 deletions(-) diff --git a/python/dune/perftool/compile.py b/python/dune/perftool/compile.py index 52f707ce..8e537016 100644 --- a/python/dune/perftool/compile.py +++ b/python/dune/perftool/compile.py @@ -10,7 +10,10 @@ import loopy from ufl.algorithms import compute_form_data, read_ufl_file from ufl.algorithms.formfiles import interpret_ufl_namespace -from dune.perftool.generation import delete_cache_items, global_context +from dune.perftool.generation import (delete_cache_items, + global_context, + subst_rule, + ) from dune.perftool.interactive import start_interactive_session from dune.perftool.options import get_option from dune.perftool.pdelab.driver import generate_driver @@ -82,6 +85,9 @@ def read_ufl(uflfile): # Enrich data by some additional objects we deem worth keeping if get_option("exact_solution_expression"): data.object_by_name[get_option("exact_solution_expression")] = namespace[get_option("exact_solution_expression")] + for name, expr in namespace.items(): + if name.startswith("cse"): + data.object_by_name[name] = namespace[name] formdatas = [] forms = data.forms diff --git a/python/dune/perftool/generation/__init__.py b/python/dune/perftool/generation/__init__.py index 3e043645..0af1aff7 100644 --- a/python/dune/perftool/generation/__init__.py +++ b/python/dune/perftool/generation/__init__.py @@ -43,6 +43,7 @@ from dune.perftool.generation.loopy import (barrier, kernel_cached, noop_instruction, silenced_warning, + subst_rule, temporary_variable, transform, valuearg, diff --git a/python/dune/perftool/generation/loopy.py b/python/dune/perftool/generation/loopy.py index 603f32f2..ac9f91bd 100644 --- a/python/dune/perftool/generation/loopy.py +++ b/python/dune/perftool/generation/loopy.py @@ -199,3 +199,13 @@ def loopy_class_member(name, classtag=None, potentially_vectorized=False, **kwar globalarg(name, **kwargs) return name + + +@generator_factory(item_tags=("substrule",), + context_tags="kernel", + cache_key_generator=lambda n, e, v: e, + ) +def subst_rule(name, expr, visitor): + if name is None: + raise ValueError("Requested a Substitution Rule that was not predefined!") + return lp.SubstitutionRule(name, (), visitor(expr)) diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py index 09a12e77..3d39a243 100644 --- a/python/dune/perftool/pdelab/localoperator.py +++ b/python/dune/perftool/pdelab/localoperator.py @@ -19,6 +19,7 @@ from dune.perftool.generation import (backend, post_include, retrieve_cache_functions, retrieve_cache_items, + subst_rule, template_parameter, ) from dune.perftool.cgen.clazz import (AccessModifier, @@ -412,6 +413,13 @@ def visit_integrals(integrals): interface = PDELabInterface() from dune.perftool.ufl.visitor import UFL2LoopyVisitor visitor = UFL2LoopyVisitor(interface, measure, indexmap) + + # Inspect the UFL file for manual common subexpression elimination stuff! + data = get_global_context_value("data") + for name, expr in data.object_by_name.items(): + if name.startswith("cse"): + subst_rule(name, expr, visitor) + get_backend(interface="accum_insn")(visitor, term, measure, subdomain_id) diff --git a/python/dune/perftool/ufl/visitor.py b/python/dune/perftool/ufl/visitor.py index f8d1f7d0..76d169c9 100644 --- a/python/dune/perftool/ufl/visitor.py +++ b/python/dune/perftool/ufl/visitor.py @@ -5,6 +5,7 @@ to pymbolic and loopy. from dune.perftool.error import PerftoolUFLError from dune.perftool.generation import (domain, get_global_context_value, + subst_rule, ) from dune.perftool.ufl.flatoperators import get_operands from dune.perftool.ufl.modified_terminals import (ModifiedTerminalTracker, @@ -45,9 +46,6 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): # Call base class constructors super(UFL2LoopyVisitor, self).__init__() - # Allow recursion through self.call(..) - call = MultiFunction.__call__ - def __call__(self, o): # Reset some state variables that are reinitialized for each accumulation term self.indices = None @@ -56,6 +54,12 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker): return self.call(o) + def call(self, o): + try: + return subst_rule(None, o, self) + except ValueError: + return MultiFunction.__call__(self, o) + # # Form argument/coefficients handlers: # This is where the actual domain specific work happens -- GitLab