From ee8ce79cc952fe5743b70176e72e5e86c69014d0 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Thu, 30 Mar 2017 15:14:09 +0200
Subject: [PATCH] Move permutation code to its own module

---
 python/dune/perftool/sumfact/permutation.py | 86 +++++++++++++++++++++
 python/dune/perftool/sumfact/sumfact.py     | 83 +-------------------
 2 files changed, 90 insertions(+), 79 deletions(-)
 create mode 100644 python/dune/perftool/sumfact/permutation.py

diff --git a/python/dune/perftool/sumfact/permutation.py b/python/dune/perftool/sumfact/permutation.py
new file mode 100644
index 00000000..34d32298
--- /dev/null
+++ b/python/dune/perftool/sumfact/permutation.py
@@ -0,0 +1,86 @@
+""" Permute sum factorization kernels """
+
+# TODO!
+# * get rid of the underscores in names
+# * Pass the entire kernel object into the strategy thing 
+
+import itertools
+
+
+def _sf_permutation_heuristic(permutations, stage):
+    """Heuristic to choose a permutation
+
+    - Stage 1: Pick the permutation where in permutations[1:] most
+      elements are ordered by size
+    - Stage 3: Pick the permutation where in permutations[:-1] most
+      elements are ordered by size
+    """
+    def cost(perm, stage):
+        cost = 0
+        for i in range(0, len(perm) - 2):
+            if stage == 1:
+                if perm[i + 1] > perm[i + 2]:
+                    cost += 1
+            if stage == 3:
+                if perm[0] > perm[i + 1]:
+                    cost += 1
+        return cost
+
+    perm = min(permutations, key=lambda i: cost(i, stage))
+    return perm
+
+
+def _sf_flop_cost(a_matrices):
+    """Computational cost of sumfactorization with this list of a_matrices
+    """
+    cost = 0
+    for l in range(len(a_matrices)):
+        cost_m = 1
+        cost_n = 1
+        for i in range(l + 1):
+            cost_m *= a_matrices[i].rows
+        for i in range(l, len(a_matrices)):
+            cost_n *= a_matrices[i].cols
+        cost += cost_m * cost_n
+    return cost
+
+
+def _sf_permutation_strategy(a_matrices, stage):
+    """Choose permutation of a_matrices list based on computational cost
+
+    Note: If there are multiple permutations with the same cost a
+    heuristic is used to pick one.
+    """
+    # Combine permutation and a_matrices list
+    perm = [i for i, _ in enumerate(a_matrices)]
+    perm_a_matrices = zip(perm, a_matrices)
+
+    # Find cost for all possible permutations of a_matrices list
+    perm_cost = []
+    for permutation in itertools.permutations(perm_a_matrices):
+        perm, series = zip(*permutation)
+        cost = _sf_flop_cost(series)
+        perm_cost.append((perm, cost))
+
+    # Find minimal cost and all permutations with that cost
+    _, costs = zip(*perm_cost)
+    minimal_cost = min(costs)
+    minimal_cost_permutations = [p[0] for p in perm_cost if p[1] == minimal_cost]
+
+    # Use heuristic to pic one of the minimal cost permutations
+    perm = _sf_permutation_heuristic(minimal_cost_permutations, stage)
+    return perm
+
+
+def _permute_forward(t, perm):
+    tmp = []
+    for pos in perm:
+        tmp.append(t[pos])
+    return tuple(tmp)
+
+
+def _permute_backward(t, perm):
+    tmp = [None] * len(t)
+    for i, pos in enumerate(perm):
+        tmp[pos] = t[i]
+    return tuple(tmp)
\ No newline at end of file
diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py
index 95fceb08..2500a69d 100644
--- a/python/dune/perftool/sumfact/sumfact.py
+++ b/python/dune/perftool/sumfact/sumfact.py
@@ -64,6 +64,10 @@ from pymbolic.primitives import (Call,
 from dune.perftool.sumfact.quadrature import quadrature_inames
 from dune.perftool.sumfact.vectorization import find_sumfact
 from loopy.symbolic import FunctionIdentifier, IdentityMapper
+from dune.perftool.sumfact.permutation import (_sf_permutation_strategy,
+                                               _permute_backward,
+                                               _permute_forward,
+                                               )
 
 import loopy as lp
 import numpy as np
@@ -360,85 +364,6 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
             insn_dep = emit_sumfact_kernel(None, restriction, insn_dep)
 
 
-def _sf_permutation_heuristic(permutations, stage):
-    """Heuristic to choose a permutation
-
-    - Stage 1: Pick the permutation where in permutations[1:] most
-      elements are ordered by size
-    - Stage 3: Pick the permutation where in permutations[:-1] most
-      elements are ordered by size
-    """
-    def cost(perm, stage):
-        cost = 0
-        for i in range(0, len(perm) - 2):
-            if stage == 1:
-                if perm[i + 1] > perm[i + 2]:
-                    cost += 1
-            if stage == 3:
-                if perm[0] > perm[i + 1]:
-                    cost += 1
-        return cost
-
-    perm = min(permutations, key=lambda i: cost(i, stage))
-    return perm
-
-
-def _sf_flop_cost(a_matrices):
-    """Computational cost of sumfactorization with this list of a_matrices
-    """
-    cost = 0
-    for l in range(len(a_matrices)):
-        cost_m = 1
-        cost_n = 1
-        for i in range(l + 1):
-            cost_m *= a_matrices[i].rows
-        for i in range(l, len(a_matrices)):
-            cost_n *= a_matrices[i].cols
-        cost += cost_m * cost_n
-    return cost
-
-
-def _sf_permutation_strategy(a_matrices, stage):
-    """Choose permutation of a_matrices list based on computational cost
-
-    Note: If there are multiple permutations with the same cost a
-    heuristic is used to pick one.
-    """
-    # Combine permutation and a_matrices list
-    perm = [i for i, _ in enumerate(a_matrices)]
-    perm_a_matrices = zip(perm, a_matrices)
-
-    # Find cost for all possible permutations of a_matrices list
-    perm_cost = []
-    for permutation in itertools.permutations(perm_a_matrices):
-        perm, series = zip(*permutation)
-        cost = _sf_flop_cost(series)
-        perm_cost.append((perm, cost))
-
-    # Find minimal cost and all permutations with that cost
-    _, costs = zip(*perm_cost)
-    minimal_cost = min(costs)
-    minimal_cost_permutations = [p[0] for p in perm_cost if p[1] == minimal_cost]
-
-    # Use heuristic to pic one of the minimal cost permutations
-    perm = _sf_permutation_heuristic(minimal_cost_permutations, stage)
-    return perm
-
-
-def _permute_forward(t, perm):
-    tmp = []
-    for pos in perm:
-        tmp.append(t[pos])
-    return tuple(tmp)
-
-
-def _permute_backward(t, perm):
-    tmp = [None] * len(t)
-    for i, pos in enumerate(perm):
-        tmp[pos] = t[i]
-    return tuple(tmp)
-
-
 @generator_factory(item_tags=("sumfactkernel",), context_tags=("kernel",), cache_key_generator=lambda a, b, s, **kw: (a, b, s, kw.get("restriction", 0)))
 def sum_factorization_kernel(a_matrices,
                              buf,
-- 
GitLab