From f7705ce6e794eb14f302d9c533ee2a7a5d65b11b Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Fri, 3 Feb 2017 16:41:40 +0100
Subject: [PATCH] Refactor substitution rules to not show up in *all* kernels

Substitution rules are now only constructed when they are actually used.
---
 python/dune/perftool/compile.py              |  1 -
 python/dune/perftool/generation/__init__.py  |  1 +
 python/dune/perftool/generation/loopy.py     | 30 ++++++++++++++------
 python/dune/perftool/pdelab/localoperator.py | 10 +++----
 python/dune/perftool/ufl/visitor.py          | 10 ++++---
 5 files changed, 34 insertions(+), 18 deletions(-)

diff --git a/python/dune/perftool/compile.py b/python/dune/perftool/compile.py
index 8e537016..d5427fa6 100644
--- a/python/dune/perftool/compile.py
+++ b/python/dune/perftool/compile.py
@@ -12,7 +12,6 @@ from ufl.algorithms.formfiles import interpret_ufl_namespace
 
 from dune.perftool.generation import (delete_cache_items,
                                       global_context,
-                                      subst_rule,
                                       )
 from dune.perftool.interactive import start_interactive_session
 from dune.perftool.options import get_option
diff --git a/python/dune/perftool/generation/__init__.py b/python/dune/perftool/generation/__init__.py
index c7f16dfc..9d71f86c 100644
--- a/python/dune/perftool/generation/__init__.py
+++ b/python/dune/perftool/generation/__init__.py
@@ -33,6 +33,7 @@ from dune.perftool.generation.cpp import (base_class,
 from dune.perftool.generation.loopy import (barrier,
                                             built_instruction,
                                             constantarg,
+                                            construct_subst_rule,
                                             domain,
                                             function_mangler,
                                             get_temporary_name,
diff --git a/python/dune/perftool/generation/loopy.py b/python/dune/perftool/generation/loopy.py
index be835b7a..8bd656eb 100644
--- a/python/dune/perftool/generation/loopy.py
+++ b/python/dune/perftool/generation/loopy.py
@@ -201,15 +201,29 @@ def loopy_class_member(name, classtag=None, potentially_vectorized=False, **kwar
     return name
 
 
-@generator_factory(item_tags=("substrule",),
+@generator_factory(item_tags=("substrule_name",),
+                   context_tags="kernel",
+                   cache_key_generator=lambda e, n: e)
+def _substrule_name(expr, name):
+    return name
+
+
+@generator_factory(item_tags=("substrule_impl",),
                    context_tags="kernel",
-                   cache_key_generator=lambda e, r: e,
+                   cache_key_generator=lambda n, e, **ex: e,
                    )
-def subst_rule(expr, rule):
-    return rule
+def subst_rule(name, expr, exists=False):
+    _substrule_name(expr, name)
+    return exists
+
 
+def set_subst_rule(name, expr):
+    subst_rule(name, expr, exists=True)
 
-def set_subst_rule(name, expr, visitor):
-    rule = lp.SubstitutionRule(name, (), visitor(expr))
-    subst_rule._memoize_cache = {k: v for k, v in subst_rule._memoize_cache.items() if v is not None}
-    return subst_rule(expr, rule)
+
+@generator_factory(item_tags=("substrule",),
+                   context_tags="kernel")
+def construct_subst_rule(expr, visitor):
+    name = _substrule_name(expr, None)
+    assert name
+    return lp.SubstitutionRule(name, (), visitor(expr, donot_check_substrules=expr))
diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py
index adda9c7f..1f6c899b 100644
--- a/python/dune/perftool/pdelab/localoperator.py
+++ b/python/dune/perftool/pdelab/localoperator.py
@@ -419,16 +419,16 @@ def visit_integrals(integrals):
             data = get_global_context_value("data")
             for name, expr in data.object_by_name.items():
                 if name.startswith("cse"):
-                    set_subst_rule(name, expr, visitor)
+                    set_subst_rule(name, expr)
 
             # Ensure CSE on detjac * quadrature weight
             domain = term.argument.argexpr.ufl_domain()
             if measure == "cell":
-                set_subst_rule("integration_factor_cell1", uc.QuadratureWeight(domain)*uc.Abs(uc.JacobianDeterminant(domain)), visitor)
-                set_subst_rule("integration_factor_cell2", uc.Abs(uc.JacobianDeterminant(domain))*uc.QuadratureWeight(domain), visitor)
+                set_subst_rule("integration_factor_cell1", uc.QuadratureWeight(domain)*uc.Abs(uc.JacobianDeterminant(domain)))
+                set_subst_rule("integration_factor_cell2", uc.Abs(uc.JacobianDeterminant(domain))*uc.QuadratureWeight(domain))
             else:
-                set_subst_rule("integration_factor_facet1", uc.FacetJacobianDeterminant(domain)*uc.QuadratureWeight(domain), visitor)
-                set_subst_rule("integration_factor_facet2", uc.QuadratureWeight(domain)*uc.FacetJacobianDeterminant(domain), visitor)
+                set_subst_rule("integration_factor_facet1", uc.FacetJacobianDeterminant(domain)*uc.QuadratureWeight(domain))
+                set_subst_rule("integration_factor_facet2", uc.QuadratureWeight(domain)*uc.FacetJacobianDeterminant(domain))
 
             get_backend(interface="accum_insn")(visitor, term, measure, subdomain_id)
 
diff --git a/python/dune/perftool/ufl/visitor.py b/python/dune/perftool/ufl/visitor.py
index f107d94b..c5e33a3d 100644
--- a/python/dune/perftool/ufl/visitor.py
+++ b/python/dune/perftool/ufl/visitor.py
@@ -3,7 +3,8 @@ This module defines the main visitor algorithm transforming ufl expressions
 to pymbolic and loopy.
 """
 from dune.perftool.error import PerftoolUFLError
-from dune.perftool.generation import (domain,
+from dune.perftool.generation import (construct_subst_rule,
+                                      domain,
                                       get_global_context_value,
                                       subst_rule,
                                       )
@@ -46,17 +47,18 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
         # Call base class constructors
         super(UFL2LoopyVisitor, self).__init__()
 
-    def __call__(self, o):
+    def __call__(self, o, donot_check_substrules=None):
         # Reset some state variables that are reinitialized for each accumulation term
         self.indices = None
         self._indices_backup = []
         self.inames = ()
+        self.donot_check_substrules = donot_check_substrules
 
         return self.call(o)
 
     def call(self, o):
-        rule = subst_rule(o, None)
-        if rule:
+        if o != self.donot_check_substrules and subst_rule(None, o):
+            rule = construct_subst_rule(o, self)
             return prim.Call(prim.Variable(rule.name), ())
         else:
             return MultiFunction.__call__(self, o)
-- 
GitLab