From 0bc540722b2e4cf0172598afd4abe4148243792a Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Tue, 24 Jan 2017 13:08:45 +0100
Subject: [PATCH] First implementation of manual CSE

---
 python/dune/perftool/generation/__init__.py  |  1 +
 python/dune/perftool/generation/cache.py     |  5 +----
 python/dune/perftool/generation/loopy.py     | 15 ++++++++++-----
 python/dune/perftool/pdelab/localoperator.py | 19 ++++++++++++-------
 python/dune/perftool/ufl/visitor.py          | 11 ++++++-----
 5 files changed, 30 insertions(+), 21 deletions(-)

diff --git a/python/dune/perftool/generation/__init__.py b/python/dune/perftool/generation/__init__.py
index 0af1aff7..c7f16dfc 100644
--- a/python/dune/perftool/generation/__init__.py
+++ b/python/dune/perftool/generation/__init__.py
@@ -43,6 +43,7 @@ from dune.perftool.generation.loopy import (barrier,
                                             kernel_cached,
                                             noop_instruction,
                                             silenced_warning,
+                                            set_subst_rule,
                                             subst_rule,
                                             temporary_variable,
                                             transform,
diff --git a/python/dune/perftool/generation/cache.py b/python/dune/perftool/generation/cache.py
index f0823bb5..29e128e0 100644
--- a/python/dune/perftool/generation/cache.py
+++ b/python/dune/perftool/generation/cache.py
@@ -107,7 +107,7 @@ class _RegisteredFunction(object):
             # get an unknown keyword...
             try:
                 val = self.on_store(self.func(*args, **kwargs))
-            except:
+            except TypeError:
                 val = self.on_store(self.func(*args, **without_context))
 
             # Maybe wrap it with a counter!
@@ -119,9 +119,6 @@ class _RegisteredFunction(object):
         # Return the result for immediate usage
         return self._get_content(cache_key)
 
-    def remove_by_value(self, val):
-        self._memoize_cache = {k: v for k, v in self._memoize_cache.items() if v != val}
-
 
 def generator_factory(**factory_kwargs):
     """ A function decorator factory
diff --git a/python/dune/perftool/generation/loopy.py b/python/dune/perftool/generation/loopy.py
index ac9f91bd..b618f0bc 100644
--- a/python/dune/perftool/generation/loopy.py
+++ b/python/dune/perftool/generation/loopy.py
@@ -203,9 +203,14 @@ def loopy_class_member(name, classtag=None, potentially_vectorized=False, **kwar
 
 @generator_factory(item_tags=("substrule",),
                    context_tags="kernel",
-                   cache_key_generator=lambda n, e, v: e,
+                   cache_key_generator=lambda e, r: e,
                    )
-def subst_rule(name, expr, visitor):
-    if name is None:
-        raise ValueError("Requested a Substitution Rule that was not predefined!")
-    return lp.SubstitutionRule(name, (), visitor(expr))
+def subst_rule(expr, rule):
+    return rule
+
+
+def set_subst_rule(name, expr, visitor):
+    rule = lp.SubstitutionRule(name, (), visitor(expr))
+    subst_rule._memoize_cache = {k: v for k, v in subst_rule._memoize_cache.items() if v is not None}
+    print [type(v) for v in subst_rule._memoize_cache.values()]
+    return subst_rule(expr, rule)
diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py
index 5d53d25d..ecd2f787 100644
--- a/python/dune/perftool/pdelab/localoperator.py
+++ b/python/dune/perftool/pdelab/localoperator.py
@@ -19,7 +19,7 @@ from dune.perftool.generation import (backend,
                                       post_include,
                                       retrieve_cache_functions,
                                       retrieve_cache_items,
-                                      subst_rule,
+                                      set_subst_rule,
                                       template_parameter,
                                       )
 from dune.perftool.cgen.clazz import (AccessModifier,
@@ -419,14 +419,14 @@ def visit_integrals(integrals):
             data = get_global_context_value("data")
             for name, expr in data.object_by_name.items():
                 if name.startswith("cse"):
-                    subst_rule(name, expr, visitor)
+                    set_subst_rule(name, expr, visitor)
 
             # Ensure CSE on detjac * quadrature weight
             domain = term.argument.argexpr.ufl_domain()
             if term.argument.restriction:
-                subst_rule("integration_factor", uc.FacetJacobianDeterminant(domain)*uc.QuadratureWeight(domain), visitor)
+                set_subst_rule("integration_factor", uc.FacetJacobianDeterminant(domain)*uc.QuadratureWeight(domain), visitor)
             else:
-                subst_rule("integration_factor", uc.JacobianDeterminant(domain)*uc.QuadratureWeight(domain), visitor)
+                set_subst_rule("integration_factor", uc.QuadratureWeight(domain)*uc.Abs(uc.JacobianDeterminant(domain)), visitor)
 
             get_backend(interface="accum_insn")(visitor, term, measure, subdomain_id)
 
@@ -467,6 +467,7 @@ def extract_kernel_from_cache(tag, wrap_in_cgen=True):
         domains = ["{[stupid] : 0<=stupid<1}"]
 
     instructions = [i for i in retrieve_cache_items("{} and instruction".format(tag))]
+    substrules = [i for i in retrieve_cache_items("{} and substrule".format(tag)) if i is not None]
     temporaries = {i.name: i for i in retrieve_cache_items("{} and temporary".format(tag))}
     arguments = [i for i in retrieve_cache_items("{} and argument".format(tag))]
     silenced = [l for l in retrieve_cache_items("{} and silenced_warning".format(tag))]
@@ -486,7 +487,7 @@ def extract_kernel_from_cache(tag, wrap_in_cgen=True):
     # Create the kernel
     from loopy import make_kernel, preprocess_kernel
     kernel = make_kernel(domains,
-                         instructions,
+                         instructions + substrules,
                          arguments,
                          temporary_variables=temporaries,
                          target=DuneTarget(),
@@ -494,14 +495,18 @@ def extract_kernel_from_cache(tag, wrap_in_cgen=True):
                          silenced_warnings=silenced,
                          name=name,
                          )
-
     from loopy import make_reduction_inames_unique
-    kernel = make_reduction_inames_unique(kernel)
+    #kernel = make_reduction_inames_unique(kernel)
 
     # Apply the transformations that were gathered during tree traversals
     for trafo in transformations:
         kernel = trafo[0](kernel, *trafo[1])
 
+    # Precompute all the substrules
+    for sr in substrules:
+        print sr.name
+        kernel = lp.precompute(kernel, sr.name)
+
     from dune.perftool.loopy import heuristic_duplication
     kernel = heuristic_duplication(kernel)
 
diff --git a/python/dune/perftool/ufl/visitor.py b/python/dune/perftool/ufl/visitor.py
index 76d169c9..1ba817c8 100644
--- a/python/dune/perftool/ufl/visitor.py
+++ b/python/dune/perftool/ufl/visitor.py
@@ -55,9 +55,10 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
         return self.call(o)
 
     def call(self, o):
-        try:
-            return subst_rule(None, o, self)
-        except ValueError:
+        rule = subst_rule(o, None)
+        if rule:
+            return prim.Call(prim.Variable(rule.name), ())
+        else:
             return MultiFunction.__call__(self, o)
 
     #
@@ -226,7 +227,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
     #
 
     def product(self, o):
-        return Product(tuple(self.call(op) for op in get_operands(o)))
+        return Product(tuple(self.call(op) for op in o.ufl_operands))
 
     def float_value(self, o):
         return o.value()
@@ -238,7 +239,7 @@ class UFL2LoopyVisitor(ModifiedTerminalTracker):
         return Quotient(self.call(o.ufl_operands[0]), self.call(o.ufl_operands[1]))
 
     def sum(self, o):
-        return Sum(tuple(self.call(op) for op in get_operands(o)))
+        return Sum(tuple(self.call(op) for op in o.ufl_operands))
 
     def zero(self, o):
         return 0
-- 
GitLab