From ea33081e78da642d2e9230d9bba1a7bd19a2a3ac Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ren=C3=A9=20He=C3=9F?= <rene.hess@iwr.uni-heidelberg.de>
Date: Mon, 16 Jan 2017 17:21:45 +0100
Subject: [PATCH] Permutation of directions is working, needs cleanup

---
 python/dune/perftool/sumfact/sumfact.py | 180 +++++++++++++++++++++++-
 1 file changed, 173 insertions(+), 7 deletions(-)

diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py
index 0f2407d4..28791c17 100644
--- a/python/dune/perftool/sumfact/sumfact.py
+++ b/python/dune/perftool/sumfact/sumfact.py
@@ -1,4 +1,5 @@
 import copy
+import itertools
 
 from dune.perftool.loopy.symbolic import substitute
 from dune.perftool.pdelab.argument import (name_accumulation_variable,
@@ -328,6 +329,71 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
             insn_dep = emit_sumfact_kernel(None, restriction, insn_dep)
 
 
+def _sf_permutation_heuristic(permutations, stage):
+    """Pick one permutation out of several candidates.
+
+    The first candidate with the lowest penalty is returned.
+
+    - Stage 1: one penalty point per j >= 1 with perm[j] > perm[j + 1]
+    - Stage 3: one penalty point per j in 1..len(perm) - 2 with
+      perm[0] > perm[j]  (NOTE(review): adjacent pairs intended? confirm)
+    - Any other stage: no penalty, so the first candidate is returned
+    """
+    def penalty(candidate):
+        inner = range(1, len(candidate) - 1)
+        if stage == 1:
+            return sum(1 for j in inner if candidate[j] > candidate[j + 1])
+        if stage == 3:
+            return sum(1 for j in inner if candidate[0] > candidate[j])
+        return 0
+
+    # min() returns the earliest candidate among equal penalties, so the
+    # order of `permutations` decides ties
+    return min(permutations, key=penalty)
+
+
+def _sf_flop_cost(a_matrices):
+    """Computational cost of sumfactorization with a_matrices in this order.
+
+    Sums, over l, prod(rows of matrices 0..l) * prod(cols of matrices l..).
+    """
+    total, rows_so_far = 0, 1
+    for step, matrix in enumerate(a_matrices):
+        rows_so_far *= matrix.rows
+        cols_left = 1
+        for later in a_matrices[step:]:
+            cols_left *= later.cols
+        total += rows_so_far * cols_left
+    return total
+
+
+def _sf_permutation_strategy(a_matrices, stage):
+    """Return the cheapest ordering of a_matrices as a tuple of indices.
+
+    Every permutation of the matrices is costed with _sf_flop_cost. If
+    several permutations share the minimal cost, the choice between them
+    is delegated to _sf_permutation_heuristic.
+    """
+    # Tag each matrix with its original position so that the position
+    # travels with the matrix through itertools.permutations
+    tagged = enumerate(a_matrices)
+
+    # Cost every ordering; entries are (positions, cost) pairs
+    candidates = []
+    for ordering in itertools.permutations(tagged):
+        positions, matrices = zip(*ordering)
+        flops = _sf_flop_cost(matrices)
+        candidates.append((positions, flops))
+
+    # Keep only the orderings that reach the minimal cost
+    cheapest = min(flops for _, flops in candidates)
+    best = [positions for positions, flops in candidates if flops == cheapest]
+
+    # Break ties between equally cheap orderings with the stage heuristic
+    winner = _sf_permutation_heuristic(best, stage)
+    return winner
+
+
 @generator_factory(item_tags=("sumfactkernel",), context_tags=("kernel",), cache_key_generator=lambda a, b, s, **kw: (a, b, s, kw.get("restriction", 0)))
 def sum_factorization_kernel(a_matrices,
                              buf,
@@ -391,7 +457,6 @@ def sum_factorization_kernel(a_matrices,
     restriction: Restriction for faces values.
     direct_input: Global data structure containing input for
         sumfactorization (e.g. when using FastDGGridOperator).
-
     """
     if get_global_context_value("dry_run", False):
         return SumfactKernel(a_matrices, buf, stage, preferred_position, restriction), frozenset()
@@ -418,6 +483,46 @@ def sum_factorization_kernel(a_matrices,
                                   within_inames=additional_inames,
                                   )})
 
+    # Decide in which order we want to process directions in the
+    # sumfactorization. A clever ordering can lead to a reduced
+    # complexity. This will e.g. happen at faces where we only have
+    # one quadrature point m_l=1 if l is the normal direction of the
+    # face.
+    #
+    # Rule of thumb: small m's early and large n's late.
+
+    # palpo TODO
+    if stage==3 and outshape!=None:
+        from IPython import embed; embed(); import sys; sys.exit("Error message")
+
+    if stage==1 or stage==3:
+        # perm = range(len(a_matrices))
+        perm = _sf_permutation_strategy(a_matrices, stage)
+    else:
+        perm = range(len(a_matrices))
+
+    # # palpo TODO
+    # print("##  PALPO")
+    # shape = [(mat.rows,mat.cols) for mat in a_matrices]
+    # print(shape)
+
+    # Permute a_matrices
+    new_a_matrices = []
+    for pos in perm:
+        new_a_matrices.append(a_matrices[pos])
+    a_matrices = tuple(new_a_matrices)
+
+    # # palpo TODO
+    # shape = [(mat.rows,mat.cols) for mat in a_matrices]
+    # print(shape)
+
+    # new_a_matrices = [None]*len(a_matrices)
+    # for i, pos in enumerate(perm):
+    #     new_a_matrices[pos] = a_matrices[i]
+    # a_matrices = tuple(new_a_matrices)
+    # shape = [(mat.rows,mat.cols) for mat in a_matrices]
+    # print(shape)
+
     # Product of all matrices
     for l, a_matrix in enumerate(a_matrices):
         # Compute the correct shapes of in- and output matrices of this matrix-matrix multiplication
@@ -449,16 +554,43 @@ def sum_factorization_kernel(a_matrices,
         if l == 0 and direct_input is not None:
             globalarg(direct_input, dtype=np.float64, shape=inp_shape)
             if a_matrix.vectorized:
+                # palpo TODO
+                assert(False)
                 input_summand = prim.Call(prim.Variable("Vec4d"),
                                           (prim.Subscript(prim.Variable(direct_input),
                                                           (k_expr,) + tuple(prim.Variable(j) for j in out_inames[1:])),))
             else:
+                # palpo TODO
+                assert(False)
                 input_summand = prim.Subscript(prim.Variable(direct_input),
-                                               (k_expr,) + tuple(prim.Variable(j) for j in out_inames[1:]) + vec_iname)
+                                               palpo + vec_iname)
         else:
             # Get a temporary that interprets the base storage of the input
             # as a column-major matrix. In later iteration of the amatrix loop
             # this reinterprets the output of the previous iteration.
+            palpo = (k_expr,) + tuple(prim.Variable(j) for j in out_inames[1:])
+            if l==0:
+                tmp_perm = [None]*len(inp_shape)
+                for i, pos in enumerate(perm):
+                    tmp_perm[pos] = inp_shape[i]
+                inp_shape = tuple(tmp_perm)
+
+                tmp_perm = [None]*len(palpo)
+                for i, pos in enumerate(perm):
+                    tmp_perm[pos] = palpo[i]
+                palpo = tuple(tmp_perm)
+
+            # tmp_perm = []
+            # for pos in perm:
+            #     tmp_perm.append(inp_shape[pos])
+            # inp_shape = tuple(tmp_perm)
+
+            # palpo = (k_expr,) + tuple(prim.Variable(j) for j in out_inames[1:])
+            # tmp_perm = []
+            # for pos in perm:
+            #     tmp_perm.append(palpo[pos])
+            # palpo = tuple(tmp_perm)
+
             inp = get_buffer_temporary(buf,
                                        shape=inp_shape + vec_shape,
                                        dim_tags=ftags)
@@ -467,7 +599,7 @@ def sum_factorization_kernel(a_matrices,
             silenced_warning('read_no_write({})'.format(inp))
 
             input_summand = prim.Subscript(prim.Variable(inp),
-                                           (k_expr,) + tuple(prim.Variable(j) for j in out_inames[1:]) + vec_iname)
+                                           palpo + vec_iname)
 
         switch_base_storage(buf)
 
@@ -478,9 +610,27 @@ def sum_factorization_kernel(a_matrices,
         # corresponding shape (out_shape[0]) goes to the end (slowest
         # direction) and everything stays column major (ftags->fortran
         # style).
-        out = get_buffer_temporary(buf,
-                                   shape=tuple(out_shape[1:]) + (out_shape[0],) + vec_shape,
-                                   dim_tags=ftags)
+        # if False:
+        if l == len(a_matrices)-1:
+            out_shape = tuple(out_shape[1:]) + (out_shape[0],)
+            tmp_perm = [None]*len(out_shape)
+            for i, pos in enumerate(perm):
+                tmp_perm[pos] = out_shape[i]
+            out_shape = tuple(tmp_perm)
+
+            # tmp_perm = []
+            # for pos in perm:
+            #     tmp_perm.append(out_shape[pos])
+            # out_shape = tuple(tmp_perm)
+
+
+            out = get_buffer_temporary(buf,
+                                       shape=out_shape + vec_shape,
+                                       dim_tags=ftags)
+        else:
+            out = get_buffer_temporary(buf,
+                                       shape=tuple(out_shape[1:]) + (out_shape[0],) + vec_shape,
+                                       dim_tags=ftags)
 
         # Write the matrix-matrix multiplication expression
         # matprod = Product((prim.Subscript(prim.Variable(a_matrix.name),
@@ -495,7 +645,23 @@ def sum_factorization_kernel(a_matrices,
             matprod = lp.Reduction("sum", k, matprod)
 
         # Here we also move the new direction (out_inames[0]) to the end
-        assignee = prim.Subscript(prim.Variable(out), tuple(prim.Variable(i) for i in out_inames[1:]) + (prim.Variable(out_inames[0]),) + vec_iname)
+        # if False:
+        if l == len(a_matrices)-1:
+            palpo = tuple(prim.Variable(i) for i in out_inames[1:]) + (prim.Variable(out_inames[0]),)
+            tmp_perm = [None]*len(palpo)
+            for i, pos in enumerate(perm):
+                tmp_perm[pos] = palpo[i]
+            palpo = tuple(tmp_perm)
+
+            # palpo = tuple(prim.Variable(i) for i in out_inames[1:]) + (prim.Variable(out_inames[0]),)
+            # tmp_perm = []
+            # for pos in perm:
+            #     tmp_perm.append(palpo[pos])
+            # palpo = tuple(tmp_perm)
+
+            assignee = prim.Subscript(prim.Variable(out), palpo + vec_iname)
+        else:
+            assignee = prim.Subscript(prim.Variable(out), tuple(prim.Variable(i) for i in out_inames[1:]) + (prim.Variable(out_inames[0]),) + vec_iname)
 
         # Issue the reduction instruction that implements the multiplication
         # at the same time store the instruction ID for the next instruction to depend on
-- 
GitLab