From 22800cf2da1fdbcf3037dfe3ee350b5549df136b Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Thu, 1 Dec 2016 17:26:08 +0100
Subject: [PATCH] Fix symdiffs!

---
 python/dune/perftool/generation/cache.py     |  3 ++
 python/dune/perftool/pdelab/localoperator.py |  6 +--
 python/dune/perftool/sumfact/sumfact.py      | 40 ++++++++++++--------
 3 files changed, 30 insertions(+), 19 deletions(-)

diff --git a/python/dune/perftool/generation/cache.py b/python/dune/perftool/generation/cache.py
index 58b67988..e4c3fb59 100644
--- a/python/dune/perftool/generation/cache.py
+++ b/python/dune/perftool/generation/cache.py
@@ -119,6 +119,9 @@ class _RegisteredFunction(object):
         # Return the result for immediate usage
         return self._get_content(cache_key)
 
+    def remove_by_value(self, val):
+        self._memoize_cache = {k:v for k, v in self._memoize_cache.items() if v != val}
+
 
 def generator_factory(**factory_kwargs):
     """ A function decorator factory
diff --git a/python/dune/perftool/pdelab/localoperator.py b/python/dune/perftool/pdelab/localoperator.py
index c1455f35..8610952b 100644
--- a/python/dune/perftool/pdelab/localoperator.py
+++ b/python/dune/perftool/pdelab/localoperator.py
@@ -488,10 +488,8 @@ def generate_kernel(integrals):
 
 def extract_kernel_from_cache(tag):
     # Preprocess some instruction!
-    from dune.perftool.sumfact.sumfact import expand_sumfact_kernels, filter_sumfact_instructions
-    instructions = [i for i in retrieve_cache_items("{} and instruction".format(tag))]
-    expand_sumfact_kernels(instructions)
-    filter_sumfact_instructions()
+    from dune.perftool.sumfact.sumfact import expand_sumfact_kernels
+    expand_sumfact_kernels(tag)
 
     # Now extract regular loopy kernel components
     from dune.perftool.loopy.target import DuneTarget
diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py
index e3042690..c6bc6254 100644
--- a/python/dune/perftool/sumfact/sumfact.py
+++ b/python/dune/perftool/sumfact/sumfact.py
@@ -16,6 +16,7 @@ from dune.perftool.generation import (backend,
                                       globalarg,
                                       iname,
                                       instruction,
+                                      retrieve_cache_items,
                                       silenced_warning,
                                       temporary_variable,
                                       transform,
@@ -139,6 +140,12 @@ def default_resolution(insns):
                                             depends_on=frozenset(*deps)
                                             )
                                   )
+                if isinstance(insn, lp.Assignment):
+                    from dune.perftool.generation.loopy import expr_instruction_impl
+                    expr_instruction_impl.remove_by_value(insn)
+                if isinstance(insn, lp.CallInstruction):
+                    from dune.perftool.generation.loopy import call_instruction_impl
+                    call_instruction_impl.remove_by_value(insn)
 
 
 def apply_sumfact_grad_vectorization(insns, stage):
@@ -153,7 +160,7 @@ def apply_sumfact_grad_vectorization(insns, stage):
 
     # Now apply some heuristics when to vectorize...
     if len(set(sumfact_kernels)) < 3:
-        default_resolution(insns)
+        return
     else:
         # Vectorize!!!
         sumfact_kernels = sorted(sumfact_kernels, key=lambda s: s.preferred_interleaving_position)
@@ -167,6 +174,7 @@ def apply_sumfact_grad_vectorization(insns, stage):
 
         # Maybe initialize the input buffer
         if sumfact_kernels[0].setup_method:
+            assert stage == 1
             shape = product(mat.cols for mat in sumfact_kernels[0].a_matrices)
             shape = (shape, 4)
             initialize_buffer(buffer, base_storage_size=4 * shape[0])
@@ -178,6 +186,7 @@ def apply_sumfact_grad_vectorization(insns, stage):
                 func, args = sumf.setup_method
                 insn_dep = insn_dep.union({func(inp, *args, additional_indices=(i,))})
         else:
+            assert stage == 3
             # No setup method defined. We need to make sure the input is correctly setup
             shape = tuple(mat.cols for mat in sumfact_kernels[0].a_matrices) + (4,)
             initialize_buffer(buffer, base_storage_size=4 * shape[0])
@@ -189,9 +198,11 @@ def apply_sumfact_grad_vectorization(insns, stage):
                 for insn in insns:
                     if isinstance(insn, lp.Assignment):
                         if get_pymbolic_basename(insn.assignee) == sumf.input_temporary:
-                            built_instruction(insn.copy(assignee=prim.Subscript(prim.Variable(inp), insn.assignee.index + (i,))))
+                            built_instruction(insn.copy(assignee=prim.Subscript(prim.Variable(inp), insn.assignee.index + (i,)),
+                                                        id=insn.id + "__{}".format(i)))
+                            insn_dep = insn_dep.union(insn.depends_on)
                             from dune.perftool.generation.loopy import expr_instruction_impl
-                            expr_instruction_impl._memoize_cache = {k: v for k, v in expr_instruction_impl._memoize_cache.items() if v.id != insn.id}
+                            expr_instruction_impl.remove_by_value(insn)
 
         # Determine the joined AMatrix
         large_a_matrices = []
@@ -225,19 +236,19 @@ def apply_sumfact_grad_vectorization(insns, stage):
                                                     depends_on=dep,
                                                     ))
 
-def expand_sumfact_kernels(insns):
-    if get_option("vectorize_grads"):
-        apply_sumfact_grad_vectorization(insns, 1)
-        apply_sumfact_grad_vectorization(insns, 3)
-    else:
-        default_resolution(insns)
+                        if isinstance(insn, lp.Assignment):
+                            from dune.perftool.generation.loopy import expr_instruction_impl
+                            expr_instruction_impl.remove_by_value(insn)
+                        if isinstance(insn, lp.CallInstruction):
+                            from dune.perftool.generation.loopy import call_instruction_impl
+                            call_instruction_impl.remove_by_value(insn)
 
 
-def filter_sumfact_instructions():
-    """ Remove all instructions that contain a SumfactKernel node """
-    from dune.perftool.generation.loopy import expr_instruction_impl, call_instruction_impl
-    expr_instruction_impl._memoize_cache = {k: v for k, v in expr_instruction_impl._memoize_cache.items() if not find_sumfact(v.expression)}
-    call_instruction_impl._memoize_cache = {k: v for k, v in call_instruction_impl._memoize_cache.items() if not find_sumfact(v.expression)}
+def expand_sumfact_kernels(tag):
+    if get_option("vectorize_grads"):
+        apply_sumfact_grad_vectorization([i for i in retrieve_cache_items("{} and instruction".format(tag))], 1)
+        apply_sumfact_grad_vectorization([i for i in retrieve_cache_items("{} and instruction".format(tag))], 3)
+    default_resolution(retrieve_cache_items("{} and instruction".format(tag)))
 
 
 @iname
@@ -389,7 +400,6 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
                      (Subscript(result, tuple(Variable(i) for i in inames)),)
                      )
                     )
-
         instruction(assignees=(),
                     expression=expr,
                     forced_iname_deps=frozenset(inames + visitor.inames),
-- 
GitLab