From 8933451e4332780df4906af794dd32b33d384f3f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ren=C3=A9=20He=C3=9F?= <rene.hess@iwr.uni-heidelberg.de>
Date: Tue, 16 Apr 2019 07:03:57 +0200
Subject: [PATCH] [skip ci][WIP] Rewrite loop transformation with direct
 accumulation

Still missing: Handling of reduction with haddsubst!
---
 .../dune/codegen/sumfact/transformations.py   | 331 +++++-------------
 1 file changed, 82 insertions(+), 249 deletions(-)

diff --git a/python/dune/codegen/sumfact/transformations.py b/python/dune/codegen/sumfact/transformations.py
index 04beaa1d..d3765e55 100644
--- a/python/dune/codegen/sumfact/transformations.py
+++ b/python/dune/codegen/sumfact/transformations.py
@@ -15,162 +15,14 @@ from dune.codegen.error import CodegenAutotuneError
 from dune.codegen.sumfact.autotune import autotune_realization
 
 
-def move_zero_assignment_up(kernel, move_up_inames):
-    if len(move_up_inames) == 0:
-        return kernel
-
-    # Find the instruction we want to move around
-    cond = lp.match.Tagged('set_zero')
-    instructions = lp.find_instructions(kernel, cond)
-    move_iname_set = set(map(lambda x: prim.Variable(x), move_up_inames))
-    instr = None
-    for i in instructions:
-        instr_iname_set = set(i.assignee.index_tuple)
-        if move_iname_set.issubset(instr_iname_set):
-            # There should be only one matching instruction
-            assert instr is None
-            instr = i
-    assert instr is not None
-
-    # Remove it
-    kernel = lp.remove_instructions(kernel, set([instr.id]))
-
-    # Create loop domains: In order to move it upwards we need to create
-    # additional loops
-    iname_appendix = '_move_up'
-    domains = kernel.domains
-    for iname in move_up_inames:
-        # Find loop bound for this iname
-        for dom in domains:
-            if iname in dom.get_var_names(isl.dim_type.set):
-                # index = dom.get_var_names(isl.dim_type.set).index(iname)
-
-                # TODO: Noch unklar wie man die Loop bound aus isl rausbekommt.
-                todo_begin = str(dom).find(iname + ' =') + len(iname) + 3
-                if todo_begin == len(iname) + 3 - 1:
-                    todo_begin = str(dom).find(iname + ' <=') + len(iname) + 4
-                todo_end = todo_begin + str(dom)[todo_begin:].find(' ')
-                loop_bound = int(str(dom)[todo_begin:todo_end]) + 1
-                break
-
-        domain = "{{ [{0}] : 0<={0}<{1} }}".format(iname + iname_appendix, loop_bound)
-        domain = lp.kernel.creation.parse_domains(domain, {})
-        domains = domains + domain
-
-    # Create tuple of correct inames for usage in  subscript below
-    indices = list(instr.assignee.index_tuple)
-    for i in range(len(indices)):
-        if indices[i].name in move_up_inames:
-            indices[i] = prim.Variable(indices[i].name + iname_appendix)
-    indices = tuple(indices)
-
-    # The new instructions needs to lie within those inames
-    within_inames = []
-    for i in indices:
-        within_inames.append(i.name)
-
-    # Create new instruction
-    assignee = prim.Subscript(instr.assignee.aggregate, indices)
-    instructions = []
-    instructions.append(instr.copy(assignee=assignee,
-                                   within_inames=frozenset(within_inames)))
-    kernel = kernel.copy(instructions=kernel.instructions + instructions,
-                         domains=domains)
-
-    # Add dependency to inner assignment instructions
-    cond = lp.match.Tagged('assignment')
-    assignment_instructions = lp.find_instructions(kernel, cond)
-    instr = None
-    for i in assignment_instructions:
-        instr_iname_set = set(i.assignee.index_tuple)
-        if move_iname_set.issubset(instr_iname_set):
-            # There should be only one matching instruction
-            assert instr is None
-            instr = i
-
-    id_zero = instructions[0].id
-    cond = lp.match.Id(instr.id)
-    kernel = lp.add_dependency(kernel, cond, id_zero)
-
-    return kernel
-
-
-def _reorder_loops_in_tensor_contraction_direct(kernel, iname_order):
-    """Reorder the loop nest of the tensor contractions
-
-    iname_order is a string that specifies the loop order. We use the following convention:
-
-    Each contraction in the sum factorization kernel has the form 'ij,jkl->kli'
-    using einsum notation from numpy. iname_order should be a string like
-    'iklj' if the loops should be done in order i, k, l, j.
-
-    Without transformations those loops will be done in the order lkij.
-
-    In the sum factorization kernel itself those inames are called:
-
-    sf_out_inames_2_* : l
-    sf_out_inames_1_* : k
-    sf_out_inames_0_* : i
-    red_* : j
-
-    where * represents the current direction (0,1,2 for 3D problems).
-
-    TODO: Maybe also support a list of the inames above?
-
-    TODO: Different order for different direction? Could make sense when we use
-    fastdg and a broadcast since the first contraction has a smaller input
-    matrix.
-    """
-    dim = world_dimension()
-    # TODO: In principle there is no need to be dimension dependent. I'm just
-    # not sure how to pass the iname_order in the general case. This probably
-    # needs a rework anyway so I just do the 3D case first.
-    assert dim == 3
-
-    kernel = remove_all_reductions(kernel)
-
-    # TODO: Doc after rewrite
-    reduction_iname = 'j'
-    iname_dict = {'l': 'sf_out_inames_2',
-                  'k': 'sf_out_inames_1',
-                  'i': 'sf_out_inames_0',
-                  'j': 'sf_red'}
-    reduction_index = iname_order.index(reduction_iname)
-    move_up_inames = list(map(lambda x: iname_dict[x], iname_order[reduction_index + 1:]))
-
-    # cond = lp.match.Tagged('set_zero')
-    cond = lp.match.Tagged('assignment')
-    instructions = lp.find_instructions(kernel, cond)
-    for instr in instructions:
-        inames = tuple(map(lambda x: x.name, instr.assignee.index_tuple))
-        current_move_up_inames = []
-        for i in inames:
-            for j in move_up_inames:
-                if i.find(j) >= 0:
-                    current_move_up_inames.append(i)
-
-        kernel = move_zero_assignment_up(kernel, current_move_up_inames)
-
-        # TODO: There should be a better method than searching the string for
-        # 'sf_red'. Unfortunately there are sometimes Call instructions due to
-        # broadcasts. That makes different ways difficult.
-        regex = re.compile('sf_red_([0-9]*)')
-        reduction_index = set(regex.findall(str(instr)))
-        assert len(reduction_index) == 1
-        reduction_index = reduction_index.pop()
-        reduction_iname = 'sf_red_{}'.format(reduction_index)
-
-        prefered_iname_order = []
-        for i in inames:
-            if i not in current_move_up_inames and i.find('vec') == -1:
-                prefered_iname_order.append(i)
-        prefered_iname_order.append(reduction_iname)
-        for i in current_move_up_inames:
-            prefered_iname_order.append(i)
-        prefered_iname_order = tuple(prefered_iname_order)
-        kernel = lp.prioritize_loops(kernel, prefered_iname_order)
-
-    return kernel
+def _current_iname_order(current_inames, new_iname_order):
+    """Sort the inames for this contraction according to new_iname order"""
+    current_iname_order = []
+    for i in new_iname_order:
+        for j in current_inames:
+            if i in j:
+                current_iname_order.append(j)
+    return current_iname_order
 
 
 def _get_inames_of_reduction(instr, iname_permutation):
@@ -211,77 +63,65 @@ def _get_inames_of_reduction(instr, iname_permutation):
     return outer_inames, reduction_iname, inner_inames, vec_inames
 
 
-def _duplicate_assignment_inames(kernel, match):
-    instructions = lp.find_instructions(kernel, match)
-    for instr in instructions:
-        assert isinstance(instr, lp.kernel.instruction.Assignment)
+def _reorder_loops_in_tensor_contraction_direct(kernel, iname_permutation):
+    """Reorder the loop nests of a tensor contraction accumulating directly in the data structure"""
+    dim = world_dimension()
 
-        # Dependencies
-        match = lp.match.Id(instr.id)
-        depends_on = instr.depends_on
-        depending = []
-        for i in kernel.instructions:
-            if instr.id in i.depends_on:
-                depending.append(i.id)
+    # Nothing to do if permutation is identity
+    if iname_permutation == tuple(range(dim + 1)):
+        return kernel
 
-        # Remove instruction
-        kernel = lp.remove_instructions(kernel, set([instr.id]))
+    # Use names used in sum factorization kernel (without the index that distinguishes the different directions)
+    default_iname_order = ['sf_out_inames_{}'.format(dim - 1 - i) for i in range(dim)] + ['sf_red']
+    from dune.codegen.sumfact.permutation import permute_backward
+    new_iname_order = permute_backward(default_iname_order, iname_permutation)
 
-        def _duplicate_name(iname):
-            iname_appendix = '_duplicate'
-            return iname.name + iname_appendix
+    kernel = remove_all_reductions(kernel)
 
-        agg_variable = kernel.temporary_variables[instr.assignee.aggregate.name]
-        vectorized = isinstance(agg_variable.dim_tags[-1], lp.kernel.array.VectorArrayDimTag)
-        if vectorized:
-            inames = instr.assignee.index_tuple[:-1]
+    for instr in kernel.instructions:
+        # Inames used in this reduction
+        outer_inames, reduction_iname, inner_inames, vec_inames = _get_inames_of_reduction(instr,
+                                                                                           iname_permutation)
+        if reduction_iname:
+            current_inames = outer_inames + [reduction_iname] + inner_inames + vec_inames
         else:
-            inames = instr.assignee.index_tuple
-
-        # Create new domains
-        domains = kernel.domains
-        new_domains = []
-        for iname in inames:
-            # Find loop bound for the corresponding domain
-            for dom in domains:
-                if iname.name in dom.get_var_names(isl.dim_type.set):
-                    # TODO There must be better way to get this information using isl
-                    str_begin = str(dom).find(iname.name + ' =') + len(iname.name) + 3
-                    if str_begin == len(iname.name) + 3 - 1:
-                        str_begin = str(dom).find(iname.name + ' <=') + len(iname.name) + 4
-                    str_end = str_begin + str(dom)[str_begin:].find(' ')
-                    loop_bound = int(str(dom)[str_begin:str_end]) + 1
-                    break
-
-            # Create new domain
-            domain = "{{ [{0}] : 0<={0}<{1} }}".format(_duplicate_name(iname), loop_bound)
-            domain = lp.kernel.creation.parse_domains(domain, {})
-            new_domains.append(domain)
-        for domain in new_domains:
-            domains = domains + domain
-
-        # Create new inames
-        new_inames = tuple(prim.Variable(_duplicate_name(i)) for i in inames)
-        if vectorized:
-            new_inames = new_inames + (instr.assignee.index_tuple[-1],)
+            current_inames = outer_inames + inner_inames + vec_inames
+        current_iname_order = _current_iname_order(current_inames,
+                                                   new_iname_order)
 
-        # Create new instruction within the new inames
-        assignee = prim.Subscript(instr.assignee.aggregate, new_inames)
-        new_instruction = instr.copy(assignee=assignee,
-                                     depends_on=depends_on,
-                                     within_inames=frozenset([i.name for i in new_inames]))
-        kernel = kernel.copy(instructions=kernel.instructions + [new_instruction],
-                             domains=domains)
+        if iname_permutation[-1] == dim:
+            kernel = lp.prioritize_loops(kernel, tuple(current_iname_order))
+            continue
 
-        # Restore dependencies
-        for dep in depending:
-            match = lp.match.Id(dep)
-            kernel = lp.add_dependency(kernel, match, new_instruction.id)
+        # palpo TODO
+        if 'haddsubst' in str(instr):
+            continue
+
+        # if  'assignment' in instr.tags or isinstance(instr.assignee, prim.Variable):
+        if 'assignment' in instr.tags:
+            # Set loop priority
+            lp.prioritize_loops(kernel, current_iname_order)
+        elif 'set_zero' in instr.tags:
+            # Duplicate inames and prioritize loops
+            duplicate_inames = tuple(i for i in inner_inames)
+            match = lp.match.Id(instr.id)
+            kernel = lp.duplicate_inames(kernel, duplicate_inames, match)
+            # palpo TODO prioritize!
+        else:
+            # palpo TODO 2D?
+            # assert reduction_iname is None
+
+            # Duplicate inames and prioritize loops
+            duplicate_inames = tuple(i for i in inner_inames)
+            match = lp.match.Id(instr.id)
+            kernel = lp.duplicate_inames(kernel, duplicate_inames, match)
+            # palpo TODO prioritize!
 
     return kernel
 
 
 def _reorder_loops_in_tensor_contraction_accum(kernel, iname_permutation):
+    """Reorder the loop nests of a tensor contraction using an accumulation variable"""
     dim = world_dimension()
 
     # Nothing to do if permutation is identity
@@ -293,33 +133,21 @@ def _reorder_loops_in_tensor_contraction_accum(kernel, iname_permutation):
     from dune.codegen.sumfact.permutation import permute_backward
     new_iname_order = permute_backward(default_iname_order, iname_permutation)
 
-    # Get the real names with direction indices in the right order
-    def _current_new_iname_order(outer, reduction, inner, new_iname_order):
-        if reduction:
-            reduction = [reduction]
-        else:
-            reduction = []
-        all_inames = outer + reduction + inner
-        current_iname_order = []
-        for i in new_iname_order:
-            for j in all_inames:
-                if i in j:
-                    current_iname_order.append(j)
-        return current_iname_order
-
     for instr in kernel.instructions:
         # Inames used in this reduction
         outer_inames, reduction_iname, inner_inames, vec_inames = _get_inames_of_reduction(instr,
                                                                                            iname_permutation)
+        if reduction_iname:
+            current_inames = outer_inames + [reduction_iname] + inner_inames + vec_inames
+        else:
+            current_inames = outer_inames + inner_inames + vec_inames
 
         # We can directly use lp.prioritize_loops if:
         # - The reduction is the innermost loop
         # - There is no reduction (eg reduced direction on faces)
         if iname_permutation[-1] == dim or reduction_iname is None:
-            current_iname_order = _current_new_iname_order(outer_inames,
-                                                           reduction_iname,
-                                                           inner_inames,
-                                                           new_iname_order)
+            current_iname_order = _current_iname_order(current_inames,
+                                                       new_iname_order)
             kernel = lp.prioritize_loops(kernel, tuple(current_iname_order))
             continue
         assert isinstance(instr.expression, lp.symbolic.Reduction)
@@ -434,24 +262,27 @@ def _reorder_loops_in_tensor_contraction_accum(kernel, iname_permutation):
             kernel = kernel.copy(temporary_variables=tv)
 
         # Reordering loops only works if we duplicate some inames
-        duplicate_inames = tuple(i.name for i in accum_init_inames)
-        if vectorized:
-            duplicate_inames = duplicate_inames[:-1]
+        duplicate_inames = tuple(inner_inames)
+
         match = lp.match.Id(accum_init_id)
         kernel = lp.duplicate_inames(kernel, duplicate_inames, match)
+        match_inames = tuple(lp.find_instructions(kernel, match)[0].within_inames)
+        current_iname_order = _current_iname_order(match_inames, new_iname_order)
+        kernel = lp.prioritize_loops(kernel, tuple(current_iname_order))
+
         # Reorder loops of the assignment of the result
         if 'haddsubst' not in str(instr):
             match = lp.match.Id(assign_id)
-            kernel = lp.duplicate_inames(kernel, duplicate_inames, match)
         else:
             match = lp.match.Id(assignment.id)
-            kernel = lp.duplicate_inames(kernel, duplicate_inames, match)
+        kernel = lp.duplicate_inames(kernel, duplicate_inames, match)
+        match_inames = tuple(lp.find_instructions(kernel, match)[0].within_inames)
+        current_iname_order = _current_iname_order(match_inames, new_iname_order)
+        kernel = lp.prioritize_loops(kernel, tuple(current_iname_order))
 
         # Change loop order
-        current_iname_order = _current_new_iname_order(outer_inames,
-                                                       reduction_iname,
-                                                       inner_inames,
-                                                       new_iname_order)
+        current_iname_order = _current_iname_order(current_inames,
+                                                   new_iname_order)
         kernel = lp.prioritize_loops(kernel, tuple(current_iname_order))
 
     return kernel
@@ -509,8 +340,6 @@ def reorder_loops_in_tensor_contraction(kernel, iname_permutation, accum_variabl
         kernel = _reorder_loops_in_tensor_contraction_accum(kernel, iname_permutation)
         return kernel
     else:
-        # TODO: Need to adapt this!
-        assert False
         kernel = _reorder_loops_in_tensor_contraction_direct(kernel, iname_permutation)
         return kernel
 
@@ -524,10 +353,11 @@ def tensor_contraction_loop_order_generator(kernel):
         if permutation[0] == dim:
             continue
 
-        new_kernel = reorder_loops_in_tensor_contraction(kernel, permutation, True)
+        new_kernel = reorder_loops_in_tensor_contraction(kernel, permutation, accum_variable=True)
         yield new_kernel, ['reorder_loops_in_tensor_contraction_{}_True'.format(permutation)]
 
-        # new_kernel = reorder_loops_in_tensor_contraction(kernel, permutation, False)
+        # palpo TODO
+        # new_kernel = reorder_loops_in_tensor_contraction(kernel, permutation, accum_variable=False)
         # yield new_kernel, ['reorder_loops_in_tensor_contraction_{}_False'.format(permutation),]
 
 
@@ -560,10 +390,13 @@ def sumfact_performance_transformations(kernel, signature):
             # # TODO
             # dim = world_dimension()
             # if dim == 2:
-            #     kernel = reorder_loops_in_tensor_contraction(kernel, (2,0,1), True)
+            #     # assert False
+            #     kernel = reorder_loops_in_tensor_contraction(kernel, (2, 0, 1), True)
+            #     # kernel = reorder_loops_in_tensor_contraction(kernel, (2, 0, 1), False)
             # else:
-            #     kernel = reorder_loops_in_tensor_contraction(kernel, (3,2,0,1), True)
-            #     # kernel = reorder_loops_in_tensor_contraction(kernel, (1,2,0,3), True)
+            #     kernel = reorder_loops_in_tensor_contraction(kernel, (3, 2, 0, 1), True)
+            #     # kernel = reorder_loops_in_tensor_contraction(kernel, (1, 2, 0, 3), True)
+            #     # kernel = reorder_loops_in_tensor_contraction(kernel, (3, 2, 0, 1), False)
 
             kernel = autotune_tensor_contraction_loop_order(kernel, signature)
             pass
-- 
GitLab