From 77c56f647c05bdf0d1718f9995761fe6cbb62643 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ren=C3=A9=20He=C3=9F?= <rene.hess@iwr.uni-heidelberg.de>
Date: Thu, 7 Sep 2017 11:22:43 +0200
Subject: [PATCH] Do not generate code for stage 1 sumfact kernels that don't
 get used

Save all stage 1 sum factorization kernels that are used in
accumulation expressions in the cache during the dry run. Discard all
inactive sum factorization kernels in decide_vectorization_strategy.
---
 python/dune/perftool/sumfact/accumulation.py  | 23 +++++++++++++
 python/dune/perftool/sumfact/basis.py         | 10 ++++++
 python/dune/perftool/sumfact/realization.py   |  1 -
 python/dune/perftool/sumfact/vectorization.py | 33 ++++++++++++++-----
 4 files changed, 57 insertions(+), 10 deletions(-)

diff --git a/python/dune/perftool/sumfact/accumulation.py b/python/dune/perftool/sumfact/accumulation.py
index 50f4d169..c9c32b48 100644
--- a/python/dune/perftool/sumfact/accumulation.py
+++ b/python/dune/perftool/sumfact/accumulation.py
@@ -6,6 +6,7 @@ from dune.perftool.pdelab.argument import (name_accumulation_variable,
 from dune.perftool.generation import (backend,
                                       domain,
                                       dump_accumulate_timer,
+                                      generator_factory,
                                       get_counted_variable,
                                       get_counter,
                                       iname,
@@ -40,10 +41,29 @@ from pytools import ImmutableRecord
 import loopy as lp
 import numpy as np
 import pymbolic.primitives as prim
+from pymbolic.mapper import WalkMapper
 import ufl.classes as uc
 from ufl import FiniteElement, MixedElement, TensorProductElement
 
 
+basis_sf_kernels = generator_factory(item_tags=("basis_sf_kernels",), context_tags='kernel', no_deco=True)
+
+
+class SumfactCollectMapper(WalkMapper):
+    def map_tagged_variable(self, expr, *args, **kwargs):
+        self.visit(expr, *args, **kwargs)
+        self.post_visit(expr, *args, **kwargs)
+
+    def map_variable(self, expr, *args, **kwargs):
+        self.visit(expr, *args, **kwargs)
+        self.post_visit(expr, *args, **kwargs)
+
+    def map_sumfact_kernel(self, expr, *args, **kwargs):
+        basis_sf_kernels(expr)
+        self.visit(expr, *args, **kwargs)
+        self.post_visit(expr, *args, **kwargs)
+
+
 @iname
 def _sumfact_iname(bound, _type, count):
     name = "sf_{}_{}".format(_type, str(count))
@@ -199,6 +219,9 @@ def generate_accumulation_instruction(expr, visitor):
     test_info = visitor.test_info
     trial_info = visitor.trial_info
 
+    # Cache all stage 1 sum factorization kernels used in this expression
+    SumfactCollectMapper()(expr)
+
     # Number of basis functions per direction
     leaf_element = test_info.element
     from ufl import MixedElement
diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py
index 8e745c7d..00ef2966 100644
--- a/python/dune/perftool/sumfact/basis.py
+++ b/python/dune/perftool/sumfact/basis.py
@@ -146,6 +146,11 @@ def pymbolic_coefficient_gradient(element, restriction, index, coeff_func, visit
     from dune.perftool.sumfact.vectorization import attach_vectorization_info
     vsf = attach_vectorization_info(sf)
 
+    # If this sum factorization kernel was not used in the dry run we
+    # just return 0
+    if vsf == 0:
+        return 0, None
+
     from dune.perftool.sumfact.realization import realize_sum_factorization_kernel
     var, insn_dep = realize_sum_factorization_kernel(vsf)
 
@@ -182,6 +187,11 @@ def pymbolic_coefficient(element, restriction, index, coeff_func, visitor_indice
     from dune.perftool.sumfact.vectorization import attach_vectorization_info
     vsf = attach_vectorization_info(sf)
 
+    # If this sum factorization kernel was not used in the dry run we
+    # just return 0
+    if vsf == 0:
+        return 0, None
+
     # Add a sum factorization kernel that implements the evaluation of
     # the basis functions at quadrature points (stage 1)
     from dune.perftool.sumfact.realization import realize_sum_factorization_kernel
diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py
index 3f1afe9c..5c9e4a5b 100644
--- a/python/dune/perftool/sumfact/realization.py
+++ b/python/dune/perftool/sumfact/realization.py
@@ -24,7 +24,6 @@ from dune.perftool.sumfact.permutation import (sumfact_permutation_strategy,
                                                permute_backward,
                                                permute_forward,
                                                )
-from dune.perftool.sumfact.vectorization import attach_vectorization_info
 from dune.perftool.sumfact.accumulation import sumfact_iname
 from dune.perftool.loopy.vcl import ExplicitVCLCast
 
diff --git a/python/dune/perftool/sumfact/vectorization.py b/python/dune/perftool/sumfact/vectorization.py
index 08e07782..3a044794 100644
--- a/python/dune/perftool/sumfact/vectorization.py
+++ b/python/dune/perftool/sumfact/vectorization.py
@@ -140,37 +140,52 @@ def decide_vectorization_strategy():
     """
     logger = logging.getLogger(__name__)
 
+    # Retrieve all sum factorization kernels for stage 1 and 3
     from dune.perftool.generation import retrieve_cache_items
-    sumfacts = [i for i in retrieve_cache_items("kernel_default and sumfactnodes")]
+    all_sumfacts = [i for i in retrieve_cache_items("kernel_default and sumfactnodes")]
+
+    # Stage 1 sumfactorizations that were actually used
+    basis_sumfacts = [i for i in retrieve_cache_items('kernel_default and basis_sf_kernels')]
+
+    # Stage 1 sum factorizations that are in the cache but will not get used
+    inactive_sumfacts = [i for i in all_sumfacts if i.stage == 1 and i not in basis_sumfacts]
+
+    # All sum factorization kernels that get used
+    active_sumfacts = [i for i in all_sumfacts if i.stage == 3 or i in basis_sumfacts]
+
+    # We map inactive sum factorization kernels to 0
     sfdict = {}
+    for sf in inactive_sumfacts:
+        sfdict[sf] = 0
 
-    logger.debug("decide_vectorization_strategy: Found {} sum factorization nodes".format(len(sumfacts)))
+    logger.debug("decide_vectorization_strategy: Found {} active sum factorization nodes"
+                 .format(len(active_sumfacts)))
 
     if get_option("vectorize_grads"):
         # Currently we base our idea here on the fact that we only group sum
         # factorization kernels with the same input.
-        inputkeys = set(sf.input_key for sf in sumfacts)
+        inputkeys = set(sf.input_key for sf in active_sumfacts)
         for inputkey in inputkeys:
             width = get_vcl_type_size(np.float64)
-            sumfact_filter = [sf for sf in sumfacts if sf.input_key == inputkey]
+            sumfact_filter = [sf for sf in active_sumfacts if sf.input_key == inputkey]
             for old, new in horizontal_vectorization_strategy(sumfact_filter, width).items():
                 sfdict[old] = new
     elif get_option("vectorize_slice"):
-        for sumfact in sumfacts:
+        for sumfact in active_sumfacts:
             width = get_vcl_type_size(np.float64)
             for old, new in vertical_vectorization_strategy(sumfact, width).items():
                 sfdict[old] = new
     elif get_option("vectorize_diagonal"):
-        inputkeys = set(sf.input_key for sf in sumfacts)
+        inputkeys = set(sf.input_key for sf in active_sumfacts)
         for inputkey in inputkeys:
             width = get_vcl_type_size(np.float64)
-            sumfact_filter = [sf for sf in sumfacts if sf.input_key == inputkey]
+            sumfact_filter = [sf for sf in active_sumfacts if sf.input_key == inputkey]
             for old, new in diagonal_vectorization_strategy(sumfact_filter, width).items():
                 sfdict[old] = new
     else:
-        for old, new in no_vectorization(sumfacts).items():
+        for old, new in no_vectorization(active_sumfacts).items():
             sfdict[old] = new
 
     # Register the results
-    for sf in sumfacts:
+    for sf in all_sumfacts:
         _cache_vectorization_info(sf, sfdict.get(sf, no_vec(sf)))
-- 
GitLab