From e7a7fe6db6656a5f37678311de3653d0e95287da Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Mon, 28 Nov 2016 16:11:19 +0100
Subject: [PATCH] Delay addition of sum factorization kernel until all tree
 visiting is done

---
 python/dune/perftool/__init__.py             |  3 +-
 python/dune/perftool/generation/__init__.py  |  1 +
 python/dune/perftool/generation/loopy.py     |  1 +
 python/dune/perftool/loopy/symbolic.py       | 85 ++++++++++++++++++++
 python/dune/perftool/pdelab/localoperator.py | 12 ++-
 python/dune/perftool/sumfact/amatrix.py      |  3 +
 python/dune/perftool/sumfact/basis.py        | 22 ++---
 python/dune/perftool/sumfact/sumfact.py      | 68 +++++++++++++---
 8 files changed, 169 insertions(+), 26 deletions(-)
 create mode 100644 python/dune/perftool/loopy/symbolic.py

diff --git a/python/dune/perftool/__init__.py b/python/dune/perftool/__init__.py
index b907a800..f656fe92 100644
--- a/python/dune/perftool/__init__.py
+++ b/python/dune/perftool/__init__.py
@@ -1,4 +1,5 @@
-from dune.perftool.options import get_option
+# Trigger imports that involve monkey patching!
+import dune.perftool.loopy.symbolic
 
 # Trigger some imports that are needed to have all backend implementations visible
 # to the selection mechanisms
diff --git a/python/dune/perftool/generation/__init__.py b/python/dune/perftool/generation/__init__.py
index 89a98fd7..93d3981f 100644
--- a/python/dune/perftool/generation/__init__.py
+++ b/python/dune/perftool/generation/__init__.py
@@ -29,6 +29,7 @@ from dune.perftool.generation.cpp import (base_class,
                                           )
 
 from dune.perftool.generation.loopy import (barrier,
+                                            built_instruction,
                                             constantarg,
                                             domain,
                                             function_mangler,
diff --git a/python/dune/perftool/generation/loopy.py b/python/dune/perftool/generation/loopy.py
index 5f20ec16..a5e073cd 100644
--- a/python/dune/perftool/generation/loopy.py
+++ b/python/dune/perftool/generation/loopy.py
@@ -15,6 +15,7 @@ iname = generator_factory(item_tags=("iname",), context_tags="kernel")
 function_mangler = generator_factory(item_tags=("mangler",), context_tags="kernel")
 silenced_warning = generator_factory(item_tags=("silenced_warning",), no_deco=True, context_tags="kernel")
 kernel_cached = generator_factory(item_tags=("default_cached",), context_tags="kernel")
+built_instruction = generator_factory(item_tags=("instruction",), context_tags="kernel", no_deco=True)
 
 
 class DuneGlobalArg(lp.GlobalArg):
diff --git a/python/dune/perftool/loopy/symbolic.py b/python/dune/perftool/loopy/symbolic.py
new file mode 100644
index 00000000..0c9d1f96
--- /dev/null
+++ b/python/dune/perftool/loopy/symbolic.py
@@ -0,0 +1,85 @@
+""" Monkey patches for loopy.symbolic
+
+Use this module to insert pymbolic nodes and the likes.
+"""
+from dune.perftool.error import PerftoolError
+from pymbolic.mapper.substitutor import make_subst_func
+
+import loopy as lp
+import pymbolic.primitives as prim
+
+
+#
+# Pymbolic nodes to insert into the symbolic language understood by loopy
+#
+
+
+class SumfactKernel(prim.Variable):
+    def __init__(self, a_matrices, buffer, insn_dep=frozenset({}), additional_inames=frozenset({})):
+        self.a_matrices = a_matrices
+        self.buffer = buffer
+        self.insn_dep = insn_dep
+        self.additional_inames = additional_inames
+
+        prim.Variable.__init__(self, "SUMFACT")
+
+    def __getinitargs__(self):
+        return (self.a_matrices, self.buffer, self.insn_dep, self.additional_inames)
+
+    def stringifier(self):
+        return lp.symbolic.StringifyMapper
+
+    init_arg_names = ("a_matrices", "buffer", "insn_dep", "additional_inames")
+
+    mapper_method = "map_sumfact_kernel"
+
+
+#
+# Mapper methods to monkey patch into the visitor base classes!
+#
+
+
+def identity_map_sumfact_kernel(self, expr, *args):
+    return expr
+
+
+def walk_map_sumfact_kernel(self, expr, *args):
+    self.visit(expr)
+
+
+def stringify_map_sumfact_kernel(self, expr, *args):
+    return "SUMFACT"
+
+
+def dependency_map_sumfact_kernel(self, expr):
+    return set()
+
+
+def needs_resolution(self, expr):
+    raise PerftoolError("SumfactKernel node is a placeholder and needs to be removed!")
+
+
+#
+# Do the actual monkey patching!!!
+#
+
+
+lp.symbolic.IdentityMapper.map_sumfact_kernel = identity_map_sumfact_kernel
+lp.symbolic.SubstitutionMapper.map_sumfact_kernel = lp.symbolic.SubstitutionMapper.map_variable
+lp.symbolic.WalkMapper.map_sumfact_kernel = walk_map_sumfact_kernel
+lp.symbolic.StringifyMapper.map_sumfact_kernel = stringify_map_sumfact_kernel
+lp.symbolic.DependencyMapper.map_sumfact_kernel = dependency_map_sumfact_kernel
+lp.target.c.codegen.expression.ExpressionToCExpressionMapper.map_sumfact_kernel = needs_resolution
+lp.type_inference.TypeInferenceMapper.map_sumfact_kernel = needs_resolution
+
+
+#
+# Some helper functions!
+#
+
+
+def substitute(expr, replacemap):
+    """ A replacement for pymbolic.mapper.subsitutor.substitute which is aware of all
+    monkey patches etc.
+    """
+    return lp.symbolic.SubstitutionMapper(make_subst_func(replacemap))(expr)
diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py
index 30454153..fb36e575 100644
--- a/python/dune/perftool/pdelab/localoperator.py
+++ b/python/dune/perftool/pdelab/localoperator.py
@@ -16,6 +16,7 @@ from dune.perftool.generation import (backend,
                                       include_file,
                                       initializer_list,
                                       post_include,
+                                      retrieve_cache_functions,
                                       retrieve_cache_items,
                                       template_parameter,
                                       )
@@ -485,9 +486,14 @@ def generate_kernel(integrals):
 
 
 def extract_kernel_from_cache(tag):
-    # Extract the information, which is needed to create a loopy kernel.
-    # First extracting it, might be useful to alter it before kernel generation.
-    from dune.perftool.generation import retrieve_cache_functions, retrieve_cache_items
+    # Preprocess some instruction!
+    from dune.perftool.sumfact.sumfact import expand_sumfact_kernels, filter_sumfact_instructions
+    instructions = [i for i in retrieve_cache_items("{} and instruction".format(tag))]
+    for insn in instructions:
+        expand_sumfact_kernels(insn)
+    filter_sumfact_instructions()
+
+    # Now extract regular loopy kernel components
     from dune.perftool.loopy.target import DuneTarget
     domains = [i for i in retrieve_cache_items("{} and domain".format(tag))]
 
diff --git a/python/dune/perftool/sumfact/amatrix.py b/python/dune/perftool/sumfact/amatrix.py
index 0296a06a..1c786393 100644
--- a/python/dune/perftool/sumfact/amatrix.py
+++ b/python/dune/perftool/sumfact/amatrix.py
@@ -42,6 +42,9 @@ class AMatrix(Record):
                         cols=cols,
                         )
 
+    def __hash__(self):
+        return hash((self.a_matrix, self.rows, self.cols))
+
 
 def quadrature_points_per_direction():
     # TODO use quadrature order from dune.perftool.pdelab.quadrature
diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py
index a54eb2e8..10021fd5 100644
--- a/python/dune/perftool/sumfact/basis.py
+++ b/python/dune/perftool/sumfact/basis.py
@@ -20,7 +20,7 @@ from dune.perftool.sumfact.amatrix import (AMatrix,
                                            quadrature_points_per_direction,
                                            )
 from dune.perftool.sumfact.sumfact import (setup_theta,
-                                           sum_factorization_kernel,
+                                           SumfactKernel,
                                            sumfact_iname,
                                            )
 from dune.perftool.sumfact.quadrature import quadrature_inames
@@ -79,10 +79,10 @@ def sumfact_evaluate_coefficient_gradient(element, name, restriction, component)
         # evaluation of the gradients of basis functions at quadrature
         # points (stage 1)
         insn_dep = setup_theta(element, restriction, component, a_matrices, buffer_name)
-        var, _ = sum_factorization_kernel(a_matrices,
-                                          buffer_name,
-                                          insn_dep=frozenset({insn_dep}),
-                                          )
+        var = SumfactKernel(a_matrices,
+                            buffer_name,
+                            insn_dep=frozenset({insn_dep}),
+                            )
 
         buffers.append(var)
 
@@ -91,7 +91,7 @@ def sumfact_evaluate_coefficient_gradient(element, name, restriction, component)
         from pymbolic.primitives import Subscript, Variable
         from dune.perftool.generation import get_backend
         assignee = Subscript(Variable(name), i)
-        expression = Subscript(Variable(buf), tuple(Variable(i) for i in quadrature_inames()))
+        expression = Subscript(buf, tuple(Variable(i) for i in quadrature_inames()))
         instruction(assignee=assignee,
                     expression=expression,
                     forced_iname_deps=frozenset(get_backend("quad_inames")()),
@@ -133,12 +133,12 @@ def pymbolic_trialfunction(element, restriction, component):
     # Add a sum factorization kernel that implements the evaluation of
     # the basis functions at quadrature points (stage 1)
     insn_dep = setup_theta(element, restriction, component, a_matrices, buffer_name)
-    var, _ = sum_factorization_kernel(a_matrices,
-                                      buffer_name,
-                                      insn_dep=frozenset({insn_dep}),
-                                      )
+    var = SumfactKernel(a_matrices,
+                        buffer_name,
+                        insn_dep=frozenset({insn_dep}),
+                        )
 
-    return prim.Subscript(prim.Variable(var),
+    return prim.Subscript(var,
                           tuple(prim.Variable(i) for i in quadrature_inames())
                           )
 
diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py
index cfef8155..691aa1db 100644
--- a/python/dune/perftool/sumfact/sumfact.py
+++ b/python/dune/perftool/sumfact/sumfact.py
@@ -1,7 +1,6 @@
 import copy
 
-from pymbolic.mapper.substitutor import substitute
-
+from dune.perftool.loopy.symbolic import substitute
 from dune.perftool.pdelab.argument import (name_accumulation_variable,
                                            name_coefficientcontainer,
                                            pymbolic_coefficient,
@@ -9,6 +8,7 @@ from dune.perftool.pdelab.argument import (name_accumulation_variable,
                                            )
 from dune.perftool.generation import (backend,
                                       barrier,
+                                      built_instruction,
                                       domain,
                                       function_mangler,
                                       get_counter,
@@ -37,6 +37,8 @@ from dune.perftool.sumfact.amatrix import (AMatrix,
                                            name_theta,
                                            name_theta_transposed,
                                            )
+from dune.perftool.loopy.symbolic import SumfactKernel
+from dune.perftool.error import PerftoolError
 from pymbolic.primitives import (Call,
                                  Product,
                                  Subscript,
@@ -46,9 +48,55 @@ from dune.perftool.sumfact.quadrature import quadrature_inames
 from loopy import Reduction, GlobalArg
 from loopy.symbolic import FunctionIdentifier
 
+import loopy as lp
+import pymbolic.primitives as prim
 from pytools import product
 
 
+class HasSumfactMapper(lp.symbolic.CombineMapper):
+    def combine(self, *args):
+        return frozenset().union(*tuple(*args))
+
+    def map_constant(self, expr):
+        return frozenset()
+
+    def map_algebraic_leaf(self, expr):
+        return frozenset()
+
+    def map_loopy_function_identifier(self, expr):
+        return frozenset()
+
+    def map_sumfact_kernel(self, expr):
+        return frozenset({expr})
+
+
+def find_sumfact(expr):
+    return HasSumfactMapper()(expr)
+
+
+def expand_sumfact_kernels(insn):
+    if isinstance(insn, (lp.Assignment, lp.CallInstruction)):
+        replace = {}
+        deps = []
+        for sumf in find_sumfact(insn.expression):
+            var, dep = sum_factorization_kernel(sumf.a_matrices, sumf.buffer, sumf.insn_dep, sumf.additional_inames)
+            replace[sumf] = prim.Variable(var)
+            deps.append(dep)
+
+        if replace:
+            built_instruction(insn.copy(expression=substitute(insn.expression, replace),
+                                        depends_on=frozenset(*deps)
+                                        )
+                              )
+
+
+def filter_sumfact_instructions():
+    """ Remove all instructions that contain a SumfactKernel node """
+    from dune.perftool.generation.loopy import expr_instruction_impl, call_instruction_impl
+    expr_instruction_impl._memoize_cache = {k: v for k, v in expr_instruction_impl._memoize_cache.items() if not find_sumfact(v.expression)}
+    call_instruction_impl._memoize_cache = {k: v for k, v in call_instruction_impl._memoize_cache.items() if not find_sumfact(v.expression)}
+
+
 @iname
 def _sumfact_iname(bound, _type, count):
     name = "sf_{}_{}".format(_type, str(count))
@@ -146,10 +194,9 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
 
         # Replace gradient iname with correct index for assignement
         replace_dict = {}
-        expression = copy.deepcopy(pymbolic_expr)
         for iname in additional_inames:
             replace_dict[Variable(iname)] = i
-        expression = substitute(expression, replace_dict)
+        expression = substitute(pymbolic_expr, replace_dict)
 
         # Issue an instruction in the quadrature loop that fills the buffer
         # with the evaluation of the contribution at all quadrature points
@@ -164,11 +211,11 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
 
         # Add a sum factorization kernel that implements the multiplication
         # with the test function (stage 3)
-        result, insn_dep = sum_factorization_kernel(a_matrices,
-                                                    buf,
-                                                    insn_dep=frozenset({contrib_dep}),
-                                                    additional_inames=frozenset(visitor.inames),
-                                                    )
+        result = SumfactKernel(a_matrices,
+                               buf,
+                               insn_dep=frozenset({contrib_dep}),
+                               additional_inames=frozenset(visitor.inames),
+                               )
 
         inames = tuple(sumfact_iname(mat.rows, 'accum') for mat in a_matrices)
 
@@ -193,7 +240,7 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
         expr = Call(PDELabAccumulationFunction(accum, rank),
                     (ansatz_lfs.get_args() +
                      test_lfs.get_args() +
-                     (Subscript(Variable(result), tuple(Variable(i) for i in inames)),)
+                     (Subscript(result, tuple(Variable(i) for i in inames)),)
                      )
                     )
 
@@ -201,7 +248,6 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
                     expression=expr,
                     forced_iname_deps=frozenset(inames + visitor.inames),
                     forced_iname_deps_is_final=True,
-                    depends_on=insn_dep,
                     )
 
         # Mark the transformation that moves the quadrature loop inside the trialfunction loops for application
-- 
GitLab