From 570692240ac2c7a0c74b38af49d5f42000957dad Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ren=C3=A9=20He=C3=9F?= <rene.hess@iwr.uni-heidelberg.de>
Date: Thu, 4 Apr 2019 11:39:51 +0200
Subject: [PATCH] [skip ci] Loop reordering with accumulation variable

---
 .../dune/codegen/sumfact/transformations.py   | 114 +++++++++++++++++-
 1 file changed, 113 insertions(+), 1 deletion(-)

diff --git a/python/dune/codegen/sumfact/transformations.py b/python/dune/codegen/sumfact/transformations.py
index 8319890b..b0a97923 100644
--- a/python/dune/codegen/sumfact/transformations.py
+++ b/python/dune/codegen/sumfact/transformations.py
@@ -4,7 +4,9 @@ import loopy as lp
 import pymbolic.primitives as prim
 import islpy as isl
 
-from dune.codegen.generation import get_global_context_value
+from dune.codegen.generation import (get_counted_variable,
+                                     get_global_context_value,
+                                     )
 from dune.codegen.loopy.transformations.remove_reductions import remove_all_reductions
 from dune.codegen.options import get_form_option, get_option
 from dune.codegen.pdelab.geometry import world_dimension
@@ -170,6 +172,115 @@ def reorder_loops_in_tensor_contraction(kernel, iname_order):
     return kernel
 
 
+def reorder_loops_in_tensor_contraction_with_accumulation_variable(kernel, iname_order):
+    dim = world_dimension()
+    assert dim == 3
+
+    # kernel = remove_all_reductions(kernel)
+    kernel = reorder_loops_in_tensor_contraction(kernel, iname_order)
+
+    cond = lp.match.Tagged('set_zero')
+    for instr in lp.find_instructions(kernel, cond):
+        assert len(instr.depends_on) == 0
+
+        # Depending on this instruction
+        depending = []
+        for i in kernel.instructions:
+            if instr.id in i.depends_on:
+                depending.append(i.id)
+        assert len(depending) == 1
+
+        active_inames = [i.name.endswith('move_up') for i in instr.assignee.index]
+        from loopy.kernel.array import VectorArrayDimTag
+        agg = kernel.temporary_variables[instr.assignee.aggregate.name]
+        if isinstance(agg.dim_tags[-1], VectorArrayDimTag):
+            active_inames[-1] = True
+
+        # Instead of setting the variable itself to zero we set an accumulation
+        # variable to zero
+        kernel = lp.remove_instructions(kernel, set([instr.id]))
+        accum_variable = get_counted_variable('accum_variable')
+        accum_init_inames = tuple(i for (i,j) in zip(instr.assignee.index, active_inames) if j)
+        assignee = prim.Subscript(prim.Variable(accum_variable,), accum_init_inames)
+        accum_init_id = instr.id + '_accum_init'
+        accum_init_instr = lp.Assignment(assignee,
+                                         0,
+                                         within_inames=instr.within_inames,
+                                         id=accum_init_id,
+                                         tags=('accum_variable',)
+                                         )
+        kernel = kernel.copy(instructions=kernel.instructions + [accum_init_instr,])
+
+        # Restore dependencies
+        for dep in depending:
+            match = lp.match.Id(dep)
+            kernel = lp.add_dependency(kernel, match, accum_init_id)
+
+        # Make accumulation variable a temporary_variable of this kernel
+        #
+        # Create dim tags for accum variable
+        dim_tags = ','.join(['f'] * sum(active_inames))
+        if isinstance(agg.dim_tags[-1], VectorArrayDimTag):
+            dim_tags = ','.join(['f'] * (sum(active_inames)-1)) + ",vec"
+
+        # Create shape for accum variable
+        shape = tuple(i for (i,j) in zip(agg.shape, active_inames) if j)
+
+        from dune.codegen.loopy.temporary import DuneTemporaryVariable
+        var = {accum_variable: DuneTemporaryVariable(accum_variable,
+                                                     dtype=agg.dtype,
+                                                     shape=shape,
+                                                     dim_tags=dim_tags,
+                                                     managed=True)}
+        kernel.temporary_variables.update(var)
+
+        # Accumulate in accumulate variable
+        #
+        # Find accumulation instruction
+        accum_instr = lp.find_instructions(kernel, lp.match.Id(depending[0]))[0]
+        assert accum_instr.assignee == accum_instr.expression.children[0]
+
+        # Dependencies
+        depends_on = accum_instr.depends_on
+        depending = []
+        for i in kernel.instructions:
+            if accum_instr.id in i.depends_on:
+                depending.append(i.id)
+
+        # Replace with accumulation in accum_variable
+        kernel = lp.remove_instructions(kernel, set([accum_instr.id]))
+        accum_inames = tuple(i for (i,j) in zip(accum_instr.assignee.index, active_inames) if j)
+        assignee = prim.Subscript(prim.Variable(accum_variable,), accum_inames)
+        expression = prim.Sum((assignee, accum_instr.expression.children[1]))
+        accum_id = accum_instr.id + '_accumvar'
+        new_accum_instr = lp.Assignment(assignee,
+                                        expression,
+                                        within_inames=accum_instr.within_inames,
+                                        id=accum_id,
+                                        depends_on=depends_on,
+                                        )
+        kernel = kernel.copy(instructions=kernel.instructions + [new_accum_instr,])
+
+        # Assign accumulation result
+        #
+        # The reduction is already done
+        within_inames = frozenset(i for i in accum_instr.within_inames if 'red' not in i)
+        assign_id = accum_instr.id + '_assign'
+        assign_instr = lp.Assignment(accum_instr.assignee,
+                                     assignee,
+                                     within_inames=within_inames,
+                                     id=assign_id,
+                                     depends_on=frozenset([accum_id,]),
+                                     )
+        kernel = kernel.copy(instructions=kernel.instructions + [assign_instr,])
+
+        for dep in depending:
+            match = lp.match.Id(dep)
+            kernel = lp.add_dependency(kernel, match, assign_id)
+
+    return kernel
+
+
 def tensor_contraction_loop_order_generator(kernel):
     dim = world_dimension()
     assert dim == 3
@@ -209,6 +320,7 @@ def autotune_tensor_contraction_loop_order(kernel, signature):
 
 def sumfact_performance_transformations(kernel, signature):
     if kernel.name.startswith('sfimpl'):
+        # kernel = reorder_loops_in_tensor_contraction_with_accumulation_variable(kernel, "ljik")
         # kernel = autotune_tensor_contraction_loop_order(kernel, signature)
 
         pass
-- 
GitLab