From 444ce6cabd93a297934eea435be9f2a16e5e84d2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ren=C3=A9=20He=C3=9F?= <rene.hess@iwr.uni-heidelberg.de>
Date: Sun, 10 Feb 2019 21:39:40 +0100
Subject: [PATCH] [skip ci][WIP] Tensor contraction reordering transformation

This commit includes some ugly code that needs to be cleaned up!
---
 .../dune/codegen/sumfact/transformations.py   | 155 ++++++++++++++++++
 1 file changed, 155 insertions(+)
 create mode 100644 python/dune/codegen/sumfact/transformations.py

diff --git a/python/dune/codegen/sumfact/transformations.py b/python/dune/codegen/sumfact/transformations.py
new file mode 100644
index 00000000..713d7fd9
--- /dev/null
+++ b/python/dune/codegen/sumfact/transformations.py
@@ -0,0 +1,155 @@
+import loopy as lp
+import pymbolic.primitives as prim
+import islpy as isl
+
+from dune.codegen.loopy.transformations.remove_reductions import remove_all_reductions
+from dune.codegen.pdelab.geometry import world_dimension
+
+def move_zero_assignment_up(knl, move_up_inames):
+    """Move the 'set_zero' instruction from the move_up_inames loops into own loops.
+
+    The single instruction tagged 'set_zero' whose assignee is indexed by all
+    of move_up_inames is removed and re-added inside newly created loops named
+    '<iname>_move_up', which run from 0 to the upper bound of the original
+    iname. The single matching 'assignment' instruction gets a dependency on it.
+
+    :arg knl: kernel to transform
+    :arg move_up_inames: names of the inames the zero assignment should no
+        longer be nested in; if empty, knl is returned unchanged
+    :returns: the transformed kernel
+    :raises ValueError: if there is not exactly one matching instruction per
+        tag or if no constant upper bound is found for one of the inames
+    """
+    # An empty set is a subset of every index set, so every tagged instruction
+    # would match below. There is nothing to move in that case anyway.
+    if not move_up_inames:
+        return knl
+    move_iname_set = set(prim.Variable(name) for name in move_up_inames)
+
+    def find_unique(kernel, tag):
+        # The instruction we look for is indexed by all of the move-up inames
+        matches = [i for i in lp.find_instructions(kernel, lp.match.Tagged(tag))
+                   if move_iname_set.issubset(i.assignee.index_tuple)]
+        if len(matches) != 1:
+            raise ValueError("Expected exactly one '{}' instruction indexed by {}, "
+                             "found {}".format(tag, move_up_inames, len(matches)))
+        return matches[0]
+
+    # Find the instruction we want to move around and remove it
+    instr = find_unique(knl, 'set_zero')
+    knl = lp.remove_instructions(knl, set([instr.id]))
+
+    # Create loop domains: In order to move it upwards we need to create
+    # additional loops
+    iname_appendix = '_move_up'
+    domains = knl.domains
+    for iname in move_up_inames:
+        # Reset per iname so that a bound found for a previous iname is never
+        # silently reused
+        loop_bound = None
+        for dom in domains:
+            # TODO: Still unclear how to get the loop bound out of isl, so it
+            # is parsed from the '<iname> <= <bound> ' part of the string form.
+            text = str(dom)
+            start = text.find(iname + ' <=')
+            if start >= 0 and iname in dom.get_var_names(isl.dim_type.set):
+                start += len(iname) + 4
+                loop_bound = int(text[start:text.index(' ', start)]) + 1
+                break
+        if loop_bound is None:
+            raise ValueError("No upper bound found for iname {}".format(iname))
+
+        domain = "{{ [{0}] : 0<={0}<{1} }}".format(iname + iname_appendix, loop_bound)
+        domains = domains + lp.kernel.creation.parse_domains(domain, {})
+
+    # Replace the move-up inames in the subscript by their new counterparts
+    indices = tuple(
+        prim.Variable(idx.name + iname_appendix) if idx.name in move_up_inames else idx
+        for idx in instr.assignee.index_tuple)
+
+    # Create the new instruction. It needs to lie within exactly the inames
+    # that appear in its subscript.
+    zero_instr = instr.copy(assignee=prim.Subscript(instr.assignee.aggregate, indices),
+                            within_inames=frozenset(idx.name for idx in indices))
+    knl = knl.copy(instructions=knl.instructions + [zero_instr], domains=domains)
+
+    # Add dependency to the inner assignment instruction
+    assignment = find_unique(knl, 'assignment')
+    knl = lp.add_dependency(knl, lp.match.Id(assignment.id), zero_instr.id)
+
+    return knl
+
+
+def reorder_loops_in_tensor_contraction(knl, iname_order):
+    """Reorder the loop nest of the tensor contractions
+
+    iname_order is a string that specifies the loop order. We use the following convention:
+
+    Each contraction in the sum factorization kernel has the form 'ij,jkl->kli'
+    using einsum notation from numpy. iname_order should be a string like
+    'iklj' if the loops should be done in order i, k, l, j.
+
+    In the sum factorization kernel itself those inames are called:
+
+    sf_out_inames_2_* : l
+    sf_out_inames_1_* : k
+    sf_out_inames_0_* : i
+    sf_red_* : j
+
+    where * represents the current direction (0,1,2 for 3D problems).
+
+    NOTE: Only the position of 'j' in iname_order is evaluated right now: the
+    inames listed after 'j' are prioritized to lie inside of the reduction
+    loop, all others outside of it. Within both groups the order of the
+    assignee subscript of the 'assignment' instruction is kept.
+
+    TODO: Maybe also support a list of the inames above?
+
+    TODO: Different order for different direction? Could make sense when we use
+    fastdg and a broadcast since the first contraction has a smaller input
+    matrix.
+
+    Raises NotImplementedError if the world dimension is not 3.
+    """
+    # TODO: In principle there is no need to be dimension dependent. I'm just
+    # not sure how to pass the iname_order in the general case. This probably
+    # needs a rework anyway so I just do the 3D case first.
+    if world_dimension() != 3:
+        raise NotImplementedError("Loop reordering is only implemented for 3D problems")
+
+    knl = remove_all_reductions(knl)
+
+    # TODO: Doc after rewrite
+    iname_dict = {'l': 'sf_out_inames_2',
+                  'k': 'sf_out_inames_1',
+                  'i': 'sf_out_inames_0',
+                  'j': 'sf_red'}
+    # Inames listed after the reduction index 'j' end up inside of the reduction loop
+    reduction_index = iname_order.index('j')
+    move_up_inames = [iname_dict[x] for x in iname_order[reduction_index + 1:]]
+
+    for instr in lp.find_instructions(knl, lp.match.Tagged('assignment')):
+        inames = tuple(idx.name for idx in instr.assignee.index_tuple)
+        # any() lists an iname only once even if several patterns match it
+        current_move_up_inames = [i for i in inames
+                                  if any(j in i for j in move_up_inames)]
+
+        # Nothing has to be hoisted if no iname follows 'j'. An empty list would
+        # also match every 'set_zero' instruction in move_zero_assignment_up.
+        if current_move_up_inames:
+            knl = move_zero_assignment_up(knl, current_move_up_inames)
+
+        # The number appended to the inames of this contraction is everything
+        # after the last '_' of the first iname
+        sf_iname_index = int(inames[0].rpartition('_')[2])
+        reduction_iname = 'sf_red_{}'.format(sf_iname_index)
+
+        # Inames not listed after 'j' come first (those containing 'vec' are left
+        # out), then the reduction iname, then the inames listed after 'j'
+        prefered_iname_order = [i for i in inames
+                                if i not in current_move_up_inames and 'vec' not in i]
+        prefered_iname_order.append(reduction_iname)
+        prefered_iname_order.extend(current_move_up_inames)
+        knl = lp.prioritize_loops(knl, tuple(prefered_iname_order))
+
+    return knl
-- 
GitLab