From e2f7995f859c55125e013027c8c974c23cbe9965 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Mon, 19 Feb 2018 15:08:38 +0100
Subject: [PATCH] WIP: precompute substitution rules via explicit transforms

Record substitution rules through a new subst_rule generator and let callers
schedule lp.precompute via the (now kwargs-aware) transform mechanism, instead
of unconditionally precomputing every substitution rule in
extract_kernel_from_cache. First user: the repeated horizontal adds on the
direct (fastdg) output path of vectorized sum factorization kernels.

---
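A minimal sketch of how the new pieces fit together (the expression 'expr'
and the iname set 'outer_inames' below are illustrative placeholders, not
taken from this patch):

    import loopy as lp
    import pymbolic.primitives as prim
    from dune.perftool.generation import subst_rule, transform

    # Register a substitution rule for an expression that would otherwise
    # be evaluated several times within the kernel.
    subst_rule("my_subst", (), expr)

    # Reference the rule instead of the plain expression ...
    expr = prim.Call(prim.Variable("my_subst"), ())

    # ... and request its precomputation into a temporary once the gathered
    # transformations are applied in extract_kernel_from_cache.
    transform(lp.precompute, "my_subst", precompute_outer_inames=outer_inames)
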
 python/dune/perftool/generation/__init__.py  |  1 +
 python/dune/perftool/generation/loopy.py     |  9 +++++++--
 python/dune/perftool/pdelab/localoperator.py | 14 +-------------
 python/dune/perftool/sumfact/realization.py  |  3 ++-
 python/dune/perftool/sumfact/symbolic.py     |  9 +++++++++
 5 files changed, 20 insertions(+), 16 deletions(-)

diff --git a/python/dune/perftool/generation/__init__.py b/python/dune/perftool/generation/__init__.py
index c8c085c1..e541e713 100644
--- a/python/dune/perftool/generation/__init__.py
+++ b/python/dune/perftool/generation/__init__.py
@@ -43,6 +43,7 @@ from dune.perftool.generation.loopy import (barrier,
                                             kernel_cached,
                                             noop_instruction,
                                             silenced_warning,
+                                            subst_rule,
                                             temporary_variable,
                                             transform,
                                             valuearg,
diff --git a/python/dune/perftool/generation/loopy.py b/python/dune/perftool/generation/loopy.py
index a97df474..df47d5e9 100644
--- a/python/dune/perftool/generation/loopy.py
+++ b/python/dune/perftool/generation/loopy.py
@@ -172,8 +172,8 @@ def noop_instruction(**kwargs):
                    context_tags="kernel",
                    cache_key_generator=no_caching,
                    )
-def transform(trafo, *args):
-    return (trafo, args)
+def transform(trafo, *args, **kwargs):
+    return (trafo, args, kwargs)
 
 
 @generator_factory(item_tags=("instruction", "barrier"),
@@ -216,3 +216,8 @@ def loopy_class_member(name, classtag=None, potentially_vectorized=False, **kwar
     globalarg(name, **kwargs)
 
     return name
+
+
+@generator_factory(item_tags=("substrule",), context_tags="kernel")
+def subst_rule(name, args, expr):
+    return lp.SubstitutionRule(name, args, expr)
diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py
index 4873920b..c56f24e2 100644
--- a/python/dune/perftool/pdelab/localoperator.py
+++ b/python/dune/perftool/pdelab/localoperator.py
@@ -539,19 +539,7 @@ def extract_kernel_from_cache(tag, wrap_in_cgen=True):
 
     # Apply the transformations that were gathered during tree traversals
     for trafo in transformations:
-        kernel = trafo[0](kernel, *trafo[1])
-
-    # Precompute all the substrules
-    for sr in kernel.substitutions:
-        tmpname = "precompute_{}".format(sr)
-        kernel = lp.precompute(kernel,
-                               sr,
-                               temporary_name=tmpname,
-                               )
-        # Vectorization strategies are actually very likely to eliminate the
-        # precomputation temporary. To avoid the temporary elimination warning
-        # we need to explicitly disable it.
-        kernel = kernel.copy(silenced_warnings=kernel.silenced_warnings + ["temp_to_write({})".format(tmpname)])
+        kernel = trafo[0](kernel, *trafo[1], **trafo[2])
 
     from dune.perftool.loopy import heuristic_duplication
     kernel = heuristic_duplication(kernel)
diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py
index 4776f4fe..4b704c44 100644
--- a/python/dune/perftool/sumfact/realization.py
+++ b/python/dune/perftool/sumfact/realization.py
@@ -194,7 +194,7 @@ def _realize_sum_factorization_kernel(sf):
             tag = "{}_{}".format(tag, "_".join(sf.within_inames))
 
         # Collect the key word arguments for the loopy instruction
-        insn_args = {"forced_iname_deps": frozenset([iname for iname in out_inames]).union(frozenset(sf.within_inames)),
+        insn_args = {"forced_iname_deps": frozenset(out_inames).union(frozenset(sf.within_inames)),
                      "forced_iname_deps_is_final": True,
                      "depends_on": insn_dep,
                      "tags": frozenset({tag}),
@@ -205,6 +205,7 @@ def _realize_sum_factorization_kernel(sf):
         # In case of direct output we directly accumulate the result
         # of the Sumfactorization into some global data structure.
         if l == len(matrix_sequence) - 1 and get_form_option('fastdg') and sf.stage == 3:
+            insn_args["forced_iname_deps"] = insn_args["forced_iname_deps"].union(frozenset({vec_iname[0].name}))
             insn_dep = sf.output.realize_direct(matprod, output_inames, out_shape, insn_args)
         else:
             # Issue the reduction instruction that implements the multiplication
diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py
index 0256d86d..4ee4960e 100644
--- a/python/dune/perftool/sumfact/symbolic.py
+++ b/python/dune/perftool/sumfact/symbolic.py
@@ -2,6 +2,7 @@
 
 from dune.perftool.options import get_option
 from dune.perftool.generation import (get_counted_variable,
+                                      subst_rule,
                                       transform,
                                       )
 from dune.perftool.pdelab.geometry import local_dimension, world_dimension
@@ -128,6 +129,14 @@ class VectorSumfactKernelOutput(SumfactKernelOutputBase):
     def realize_direct(self, result, inames, shape, args):
         outputs = set(self.outputs)
 
+        # If multiple horizontal adds are to be performed on 'result',
+        # precompute it once through a substitution rule.
+        if len(outputs) > 1:
+            substname = "haddsubst_{}".format("_".join([i.name for i in inames]))
+            subst_rule(substname, (), result)
+            result = prim.Call(prim.Variable(substname), ())
+            transform(lp.precompute, substname, precompute_outer_inames=args["forced_iname_deps"])
+
         deps = frozenset()
         for o in outputs:
             hadd_result = self._add_hadd(o, result)
-- 
GitLab
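
For reference: with the three-tuple form (trafo, args, kwargs) recorded by
the updated transform() generator, extract_kernel_from_cache applies an
entry such as (the iname names here are illustrative)

    (lp.precompute, ("haddsubst_i0_i1",), {"precompute_outer_inames": frozenset({"i0", "i1"})})

as

    kernel = lp.precompute(kernel, "haddsubst_i0_i1",
                           precompute_outer_inames=frozenset({"i0", "i1"}))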