From 3923568efbe35db58998ebb24cf7eef171a1bbc8 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Mon, 23 Jan 2017 15:53:04 +0100
Subject: [PATCH] Introduce a general infrastructure to detect common
 subexpressions

---
 python/dune/perftool/compile.py              |  8 +++++++-
 python/dune/perftool/generation/__init__.py  |  1 +
 python/dune/perftool/generation/loopy.py     | 10 ++++++++++
 python/dune/perftool/pdelab/localoperator.py |  8 ++++++++
 python/dune/perftool/ufl/visitor.py          | 10 +++++++---
 5 files changed, 33 insertions(+), 4 deletions(-)

diff --git a/python/dune/perftool/compile.py b/python/dune/perftool/compile.py
index 52f707ce..8e537016 100644
--- a/python/dune/perftool/compile.py
+++ b/python/dune/perftool/compile.py
@@ -10,7 +10,10 @@ import loopy
 from ufl.algorithms import compute_form_data, read_ufl_file
 from ufl.algorithms.formfiles import interpret_ufl_namespace
 
-from dune.perftool.generation import delete_cache_items, global_context
+from dune.perftool.generation import (delete_cache_items,
+                                      global_context,
+                                      subst_rule,
+                                      )
 from dune.perftool.interactive import start_interactive_session
 from dune.perftool.options import get_option
 from dune.perftool.pdelab.driver import generate_driver
@@ -82,6 +85,9 @@ def read_ufl(uflfile):
     # Enrich data by some additional objects we deem worth keeping
     if get_option("exact_solution_expression"):
         data.object_by_name[get_option("exact_solution_expression")] = namespace[get_option("exact_solution_expression")]
+    for name, expr in namespace.items():
+        if name.startswith("cse"):
+            data.object_by_name[name] = namespace[name]
 
     formdatas = []
     forms = data.forms
diff --git a/python/dune/perftool/generation/__init__.py b/python/dune/perftool/generation/__init__.py
index 3e043645..0af1aff7 100644
--- a/python/dune/perftool/generation/__init__.py
+++ b/python/dune/perftool/generation/__init__.py
@@ -43,6 +43,7 @@ from dune.perftool.generation.loopy import (barrier,
                                             kernel_cached,
                                             noop_instruction,
                                             silenced_warning,
+                                            subst_rule,
                                             temporary_variable,
                                             transform,
                                             valuearg,
diff --git a/python/dune/perftool/generation/loopy.py b/python/dune/perftool/generation/loopy.py
index 603f32f2..ac9f91bd 100644
--- a/python/dune/perftool/generation/loopy.py
+++ b/python/dune/perftool/generation/loopy.py
@@ -199,3 +199,13 @@ def loopy_class_member(name, classtag=None, potentially_vectorized=False, **kwar
     globalarg(name, **kwargs)
 
     return name
+
+
+@generator_factory(item_tags=("substrule",),
+                   context_tags="kernel",
+                   cache_key_generator=lambda n, e, v: e,
+                   )
+def subst_rule(name, expr, visitor):
+    if name is None:
+        raise ValueError("Requested a Substitution Rule that was not predefined!")
+    return lp.SubstitutionRule(name, (), visitor(expr))
diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py
index 09a12e77..3d39a243 100644
--- a/python/dune/perftool/pdelab/localoperator.py
+++ b/python/dune/perftool/pdelab/localoperator.py
@@ -19,6 +19,7 @@ from dune.perftool.generation import (backend,
                                       post_include,
                                       retrieve_cache_functions,
                                       retrieve_cache_items,
+                                      subst_rule,
                                       template_parameter,
                                       )
 from dune.perftool.cgen.clazz import (AccessModifier,
@@ -412,6 +413,13 @@ def visit_integrals(integrals):
                 interface = PDELabInterface()
             from dune.perftool.ufl.visitor import UFL2LoopyVisitor
             visitor = UFL2LoopyVisitor(interface, measure, indexmap)
+
+            # Inspect the UFL file for manual common subexpression elimination stuff!
+            data = get_global_context_value("data")
+            for name, expr in data.object_by_name.items():
+                if name.startswith("cse"):
+                    subst_rule(name, expr, visitor)
+
             get_backend(interface="accum_insn")(visitor, term, measure, subdomain_id)
 
 
diff --git a/python/dune/perftool/ufl/visitor.py b/python/dune/perftool/ufl/visitor.py
index f8d1f7d0..76d169c9 100644
--- a/python/dune/perftool/ufl/visitor.py
+++ b/python/dune/perftool/ufl/visitor.py
@@ -5,6 +5,7 @@ to pymbolic and loopy.
 from dune.perftool.error import PerftoolUFLError
 from dune.perftool.generation import (domain,
                                       get_global_context_value,
+                                      subst_rule,
                                       )
 from dune.perftool.ufl.flatoperators import get_operands
 from dune.perftool.ufl.modified_terminals import (ModifiedTerminalTracker,
@@ -45,9 +46,6 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
         # Call base class constructors
         super(UFL2LoopyVisitor, self).__init__()
 
-    # Allow recursion through self.call(..)
-    call = MultiFunction.__call__
-
     def __call__(self, o):
         # Reset some state variables that are reinitialized for each accumulation term
         self.indices = None
@@ -56,6 +54,12 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
 
         return self.call(o)
 
+    def call(self, o):
+        try:
+            return subst_rule(None, o, self)
+        except ValueError:
+            return MultiFunction.__call__(self, o)
+
     #
     # Form argument/coefficients handlers:
     # This is where the actual domain specific work happens
-- 
GitLab