From 5292297f29bca5bdc6d360afccf6778cd1c118d9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ren=C3=A9=20He=C3=9F?= <rene.hess@iwr.uni-heidelberg.de>
Date: Mon, 18 Feb 2019 15:39:24 +0100
Subject: [PATCH] Add infrastructure for sumfact kernel transformations

The SumfactKernel object stores a tuple of transformations that are
registered during realization and applied in the local operator. These
transformations can change the name of the generated sum factorization
kernel function, which should make it possible to use autotuning to
pick transformations.
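
As an illustration (mirroring the commented-out example in
attach_transformations below), a loop reordering could be attached like
this; the order string 'kjli' is only an example, not a tuned choice:

    # Store the transformation in the vectorized sumfact kernel
    from dune.codegen.sumfact.transformations import LoopOrderTransformation
    trafo = LoopOrderTransformation('kjli')
    vsf = vsf.copy(transformations=vsf.transformations + (trafo,))

    # Map the original kernel to the transformed one so that later
    # lookups return the kernel carrying the transformation
    from dune.codegen.sumfact.vectorization import _cache_vectorization_info
    _cache_vectorization_info(sf, vsf)

The name_appendix() of every attached transformation is appended to the
generated kernel function name, so transformed and untransformed kernels
can be told apart, e.g. by an autotuning run.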
---
 .../transformations/remove_reductions.py      |  2 +-
 python/dune/codegen/sumfact/accumulation.py   |  3 +
 python/dune/codegen/sumfact/basis.py          |  3 +
 python/dune/codegen/sumfact/geometry.py       |  6 ++
 python/dune/codegen/sumfact/realization.py    |  7 +-
 python/dune/codegen/sumfact/symbolic.py       | 19 +++++-
 .../dune/codegen/sumfact/transformations.py   | 68 ++++++++++++++++---
 7 files changed, 95 insertions(+), 13 deletions(-)
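
Note: the contraction pattern 'ij,jkl->kli' added to the SumfactKernel
docstring corresponds, in plain numpy, to the following sketch (sizes
are chosen arbitrarily and a single matrix A is reused for all three
directions, which is a simplification):

    import numpy as np

    m, n = 4, 3                    # quadrature points / basis functions per direction
    A = np.random.rand(m, n)       # 1D matrix of one direction
    X = np.random.rand(n, n, n)    # input coefficient tensor

    # Each contraction reduces over 'j' and rotates the result so that the
    # next reduction index ends up in front again
    Y1 = np.einsum('ij,jkl->kli', A, X)   # shape (n, n, m)
    Y2 = np.einsum('ij,jkl->kli', A, Y1)  # shape (n, m, m)
    Y3 = np.einsum('ij,jkl->kli', A, Y2)  # shape (m, m, m)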

diff --git a/python/dune/codegen/loopy/transformations/remove_reductions.py b/python/dune/codegen/loopy/transformations/remove_reductions.py
index c7848a31..f5381612 100644
--- a/python/dune/codegen/loopy/transformations/remove_reductions.py
+++ b/python/dune/codegen/loopy/transformations/remove_reductions.py
@@ -39,7 +39,7 @@ def remove_reduction(knl, match):
                                               within_inames=within_inames,
                                               id=id_accum,
                                               depends_on=frozenset((id_zero,) + tuple(depends_on)),
-                                              tags=('assignement',)))
+                                              tags=('assignment',)))
 
 
             knl = knl.copy(instructions=knl.instructions + instructions)
diff --git a/python/dune/codegen/sumfact/accumulation.py b/python/dune/codegen/sumfact/accumulation.py
index ae4a7630..09160ea5 100644
--- a/python/dune/codegen/sumfact/accumulation.py
+++ b/python/dune/codegen/sumfact/accumulation.py
@@ -547,6 +547,9 @@ def generate_accumulation_instruction(expr, visitor):
     from dune.codegen.sumfact.vectorization import attach_vectorization_info
     vsf = attach_vectorization_info(sf)
 
+    from dune.codegen.sumfact.transformations import attach_transformations
+    vsf = attach_transformations(sf, vsf)
+
     # Make sure we have a buffer that we can set up the input with
     buffer = vsf.buffer
     if buffer is None:
diff --git a/python/dune/codegen/sumfact/basis.py b/python/dune/codegen/sumfact/basis.py
index c9b75eb4..d3a2a190 100644
--- a/python/dune/codegen/sumfact/basis.py
+++ b/python/dune/codegen/sumfact/basis.py
@@ -220,6 +220,9 @@ class SumfactBasisMixin(GenericBasisMixin):
         from dune.codegen.sumfact.vectorization import attach_vectorization_info
         vsf = attach_vectorization_info(sf)
 
+        from dune.codegen.sumfact.transformations import attach_transformations
+        vsf = attach_transformations(sf, vsf)
+
         self.indices = None
 
         # If this sum factorization kernel was not used in the dry run we
diff --git a/python/dune/codegen/sumfact/geometry.py b/python/dune/codegen/sumfact/geometry.py
index d37558f5..689c2b60 100644
--- a/python/dune/codegen/sumfact/geometry.py
+++ b/python/dune/codegen/sumfact/geometry.py
@@ -224,6 +224,9 @@ class SumfactMultiLinearGeometryMixin(GenericPDELabGeometryMixin):
         from dune.codegen.sumfact.vectorization import attach_vectorization_info
         vsf = attach_vectorization_info(sf)
 
+        from dune.codegen.sumfact.transformations import attach_transformations
+        vsf = attach_transformations(sf, vsf)
+
         # If this sum factorization kernel was not used in the dry run we
         # just return 0
         if vsf == 0:
@@ -541,6 +544,9 @@ def _name_jacobian(i, j, restriction, visitor):
     from dune.codegen.sumfact.vectorization import attach_vectorization_info
     vsf = attach_vectorization_info(sf)
 
+    from dune.codegen.sumfact.transformations import attach_transformations
+    vsf = attach_transformations(sf, vsf)
+
     # If this sum factorization kernel was not used in the dry run we
     # just return 0
     if vsf == 0:
diff --git a/python/dune/codegen/sumfact/realization.py b/python/dune/codegen/sumfact/realization.py
index 5317268f..3a628576 100644
--- a/python/dune/codegen/sumfact/realization.py
+++ b/python/dune/codegen/sumfact/realization.py
@@ -30,7 +30,6 @@ from dune.codegen.sumfact.quadrature import quadrature_points_per_direction
 from dune.codegen.sumfact.symbolic import (SumfactKernel,
                                            VectorizedSumfactKernel,
                                            )
-from dune.codegen.sumfact.vectorization import attach_vectorization_info
 from dune.codegen.sumfact.accumulation import sumfact_iname
 from dune.codegen.loopy.target import dtype_floatingpoint
 from dune.codegen.loopy.vcl import ExplicitVCLCast
@@ -279,8 +278,10 @@ def realize_sumfact_kernel_function(sf):
                                   })
 
     # Register kernel transformations
-    from dune.codegen.sumfact.transformations import reorder_loops_in_tensor_contraction
-    transform(reorder_loops_in_tensor_contraction, 'lkji')
+    for trafo in sf.transformations:
+        # kernel_transformation() returns a (transformation, argument, kwargs) tuple
+        transformation, argument, kwargs = trafo.kernel_transformation()
+        transform(transformation, argument, **kwargs)
 
     # Construct a loopy kernel object
     from dune.codegen.pdelab.localoperator import extract_kernel_from_cache
diff --git a/python/dune/codegen/sumfact/symbolic.py b/python/dune/codegen/sumfact/symbolic.py
index 60641b0c..3fddff62 100644
--- a/python/dune/codegen/sumfact/symbolic.py
+++ b/python/dune/codegen/sumfact/symbolic.py
@@ -473,6 +473,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
                  insn_dep=frozenset(),
                  interface=SumfactKernelInterfaceBase(),
                  predicates=frozenset(),
+                 transformations=(),
                  ):
         """Create a sum factorization kernel
 
@@ -503,6 +504,9 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
         the transformation brings the next reduction index in the fastest
         position.
 
+        In numpy's einsum notation this can be written as three contractions
+        of the form 'ij,jkl->kli'.
+
         It can make sense to permute the order of directions. If you have
         a small m_l (e.g. stage 1 on faces) it is better to do direction l
         first. This can be done by:
@@ -533,6 +537,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
             other.
         interface: An SumfactKernelInterfaceBase instance describing the input
             (stage 1) or output (stage 3) of the kernel
+
         """
         # Assert the inputs!
         assert isinstance(matrix_sequence, tuple)
@@ -595,6 +600,10 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
             name_quad_perm = "_qpperm_{}".format("".join(str(a) for a in self.interface.quadrature_permutation))
             name = name + name_quad_perm
 
+        # Change name for applied transformations
+        for t in self.transformations:
+            name = name + '_' + t.name_appendix()
+
         return name
 
     @property
@@ -828,6 +837,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
                  vertical_width=1,
                  buffer=None,
                  insn_dep=frozenset(),
+                 transformations=(),
                  ):
         # Assert the input data structure
         assert isinstance(kernels, tuple)
@@ -854,6 +864,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
                                  buffer=buffer,
                                  insn_dep=insn_dep,
                                  vertical_width=vertical_width,
+                                 transformations=transformations,
                                  )
 
         prim.Variable.__init__(self, "VecSUMFAC")
@@ -880,9 +891,15 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
     #
     @property
     def function_name(self):
-        return "sfimpl_{}{}".format("_".join(str(m) for m in self.matrix_sequence_quadrature_permuted),
+        name = "sfimpl_{}{}".format("_".join(str(m) for m in self.matrix_sequence_quadrature_permuted),
                                     self.interface.function_name_suffix)
 
+        # Change name for applied transformations
+        for t in self.transformations:
+            name = name + '_' + t.name_appendix()
+
+        return name
+
     @property
     def cache_key(self):
         """ The cache key that can be used in generation magic
diff --git a/python/dune/codegen/sumfact/transformations.py b/python/dune/codegen/sumfact/transformations.py
index 713d7fd9..76cf241c 100644
--- a/python/dune/codegen/sumfact/transformations.py
+++ b/python/dune/codegen/sumfact/transformations.py
@@ -2,10 +2,14 @@ import loopy as lp
 import pymbolic.primitives as prim
 import islpy as isl
 
+from dune.codegen.generation import get_global_context_value
 from dune.codegen.loopy.transformations.remove_reductions import remove_all_reductions
 from dune.codegen.pdelab.geometry import world_dimension
 
 def move_zero_assignment_up(knl, move_up_inames):
+    if len(move_up_inames) == 0:
+        return knl
+
     # Find the instruction we want to move around
     cond = lp.match.Tagged('set_zero')
     instructions = lp.find_instructions(knl, cond)
@@ -33,7 +37,9 @@ def move_zero_assignment_up(knl, move_up_inames):
                 # index = dom.get_var_names(isl.dim_type.set).index(iname)
 
                 # TODO: Still unclear how to get the loop bound out of isl.
-                todo_begin = str(dom).find(iname + ' <=') + len(iname) + 4
+                todo_begin = str(dom).find(iname + ' =') + len(iname) + 3
+                if todo_begin == len(iname) + 3 - 1:  # find() returned -1: 'iname =' not present, fall back to 'iname <='
+                    todo_begin = str(dom).find(iname + ' <=') + len(iname) + 4
                 todo_end = todo_begin + str(dom)[todo_begin:].find(' ')
                 loop_bound = int(str(dom)[todo_begin:todo_end]) + 1
                 break
@@ -134,13 +140,16 @@ def reorder_loops_in_tensor_contraction(knl, iname_order):
 
         knl = move_zero_assignment_up(knl, current_move_up_inames)
 
-        # TODO
-        #
-        # Finde the number appended to the inames of this contraction by taking
-        # all the number starting from the last '_'. There is definitely a more
-        # elegant way to find that ;).
-        sf_iname_index = int(inames[0][len(inames[0]) - inames[0][::-1].find('_'):])
-        reduction_iname = 'sf_red_{}'.format(sf_iname_index)
+        # TODO: There should be a better way than searching the string for
+        # 'sf_red'. Unfortunately there are sometimes Call instructions due to
+        # broadcasts, which makes other approaches difficult.
+        import re
+        regex = re.compile('sf_red_([0-9]+)')
+        reduction_index = set(regex.findall(str(instr)))
+        assert len(reduction_index) == 1
+        reduction_index = reduction_index.pop()
+        reduction_iname = 'sf_red_{}'.format(reduction_index)
+
 
         prefered_iname_order = []
         for i in inames:
@@ -153,3 +162,46 @@ def reorder_loops_in_tensor_contraction(knl, iname_order):
         knl = lp.prioritize_loops(knl, prefered_iname_order)
 
     return knl
+
+
+class SumfactKernelFunctionTransformation(object):
+    def kernel_transformation(self):
+        """Transformation that will be applied to sumfact kernel function"""
+        raise NotImplementedError
+
+    def name_appendix(self):
+        """Name will be appended to name of sumfact kernel function"""
+        raise NotImplementedError
+
+
+class LoopOrderTransformation(SumfactKernelFunctionTransformation):
+    def __init__(self, order):
+        self.order = order
+
+    def kernel_transformation(self):
+        return (reorder_loops_in_tensor_contraction, self.order, {})
+
+    def name_appendix(self):
+        return 'looporder{}'.format(self.order)
+
+
+def attach_transformations(sf, vsf):
+    if vsf == 0:
+        return 0
+
+    if get_global_context_value("dry_run") is None:
+
+        # Example of such a transformation (uncomment to test it):
+
+        # # Store transformation in sumfact kernel
+        # from dune.codegen.sumfact.transformations import LoopOrderTransformation
+        # trafo = LoopOrderTransformation('kjli')
+        # vsf = vsf.copy(transformations=vsf.transformations + (trafo,))
+
+        # # Map the original kernel to the new one
+        # from dune.codegen.sumfact.vectorization import _cache_vectorization_info
+        # _cache_vectorization_info(sf, vsf)
+
+        return vsf
+
+    return sf
-- 
GitLab