From 1878563e87db3bc578fa1a31513ae5abf9f92304 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Thu, 8 Feb 2018 16:50:39 +0100
Subject: [PATCH] Fix minimization over number of quadrature points

So far, it did not set the qp tuple correctly and
thus sometimes resulted in too big numbers being
chosen.
---
 python/dune/perftool/sumfact/vectorization.py | 22 ++++++++++++-------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/python/dune/perftool/sumfact/vectorization.py b/python/dune/perftool/sumfact/vectorization.py
index 38ec9b16..d7fb6776 100644
--- a/python/dune/perftool/sumfact/vectorization.py
+++ b/python/dune/perftool/sumfact/vectorization.py
@@ -95,10 +95,12 @@ def explicit_costfunction(sf):
         return 1000000000000
 
 
-def strategy_cost(strategy):
+def strategy_cost(strat_tuple):
+    qp, strategy = strat_tuple
     func = get_backend(interface="vectorization_strategy",
                        selector=lambda: get_form_option("vectorization_strategy"))
     keys = set(sf.cache_key for sf in strategy.values())
+    set_quadrature_points(qp)
 
     # Sum over all the sum factorization kernels in the realization
     score = 0.0
@@ -110,6 +112,13 @@ def strategy_cost(strategy):
     return score
 
 
+def fixedqp_strategy_costfunction(qp):
+    def _cost(strategy):
+        return strategy_cost((qp, strategy))
+
+    return _cost
+
+
 def stringify_vectorization_strategy(strategy):
     result = []
     qp, strategy = strategy
@@ -189,11 +198,10 @@ def decide_vectorization_strategy():
             for qp in quad_points:
                 for strat in fixed_quad_vectorization_opportunity_generator(frozenset(stage1_sumfacts), width, qp):
                     if strategy == int(get_form_option("vectorization_list_index")):
-                        set_quadrature_points(qp)
                         # Output the strategy and its cost into a separate file
                         if get_global_context_value("form_type") == "jacobian_apply":
                             with open("strategycosts.csv", "a") as f:
-                                f.write("{} {}\n".format(strategy, strategy_cost(strat)))
+                                f.write("{} {}\n".format(strategy, strategy_cost((qp, strat))))
                         return qp, strat
                     strategy = strategy + 1
 
@@ -210,12 +218,12 @@ def decide_vectorization_strategy():
         for key in keys:
             key_sumfacts = frozenset(sf for sf in active_sumfacts if sf.input_key == key)
             minimum = min(fixed_quad_vectorization_opportunity_generator(key_sumfacts, width, qp),
-                          key=strategy_cost)
+                          key=fixedqp_strategy_costfunction(qp))
             sfdict = add_to_frozendict(sfdict, minimum)
     else:
         # Find the minimum cost strategy between all the quadrature point tuples
         optimal_strategies = {qp: fixed_quadrature_optimal_vectorization(active_sumfacts, width, qp) for qp in quad_points}
-        qp = min(optimal_strategies, key=lambda qp: strategy_cost(optimal_strategies[qp]))
+        qp = min(optimal_strategies, key=lambda qp: strategy_cost((qp, optimal_strategies[qp])))
         sfdict = optimal_strategies[qp]
 
     set_quadrature_points(qp)
@@ -238,8 +246,6 @@ def fixed_quadrature_optimal_vectorization(sumfacts, width, qp):
     opportunities and score them individually, but we need to do a divide and conquer
     approach.
     """
-    set_quadrature_points(qp)
-
     # Find the sets of simultaneously realizable kernels (thats an equivalence relation)
     keys = frozenset(sf.input_key for sf in sumfacts)
 
@@ -249,7 +255,7 @@ def fixed_quadrature_optimal_vectorization(sumfacts, width, qp):
     for key in keys:
         key_sumfacts = frozenset(sf for sf in sumfacts if sf.input_key == key)
         minimum = min(fixed_quad_vectorization_opportunity_generator(key_sumfacts, width, qp),
-                      key=strategy_cost)
+                      key=fixedqp_strategy_costfunction(qp))
         sfdict = add_to_frozendict(sfdict, minimum)
 
     return sfdict
-- 
GitLab