From e16b2619341e1505e0746f9553be73e2ad0dc01b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ren=C3=A9=20He=C3=9F?= <rene.hess@iwr.uni-heidelberg.de>
Date: Thu, 15 Nov 2018 21:07:51 +0100
Subject: [PATCH] [skip ci][wip] Move input permutation to interface classes

So far this is only done for the fastdg case; the non-fastdg case still needs the same change.
---
 python/dune/codegen/sumfact/basis.py       | 14 ++++++++-
 python/dune/codegen/sumfact/realization.py | 24 +++++++---------
 python/dune/codegen/sumfact/symbolic.py    | 33 +++++++++++++++++++---
 3 files changed, 52 insertions(+), 19 deletions(-)

diff --git a/python/dune/codegen/sumfact/basis.py b/python/dune/codegen/sumfact/basis.py
index 3ce2fb24..ccc34406 100644
--- a/python/dune/codegen/sumfact/basis.py
+++ b/python/dune/codegen/sumfact/basis.py
@@ -24,7 +24,8 @@ from dune.codegen.sumfact.tabulation import (basis_functions_per_direction,
                                              name_polynomials,
                                              polynomial_degree,
                                              )
-from dune.codegen.sumfact.permutation import (permute_forward,
+from dune.codegen.sumfact.permutation import (permute_backward,
+                                              permute_forward,
                                               sumfact_cost_permutation_strategy,
                                               sumfact_quadrature_permutation_strategy,
                                               )
@@ -144,6 +145,17 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord):
 
         return prim.Subscript(prim.Variable(arg), inames)
 
+    def realize_input(self, shape, inames, which=0):
+        if self.direct_is_possible:
+            shape = permute_backward(shape, self.cost_permutation)
+            shape = permute_backward(shape, self.quadrature_permutation)
+            inames = permute_backward(inames, self.cost_permutation)
+            inames = permute_backward(inames, self.quadrature_permutation)
+
+            return self.realize_direct(shape, inames)
+        else:
+            raise NotImplementedError("TODO")
+
     @property
     def function_name_suffix(self):
         if get_form_option("fastdg"):
diff --git a/python/dune/codegen/sumfact/realization.py b/python/dune/codegen/sumfact/realization.py
index 1d951eca..b9d209c3 100644
--- a/python/dune/codegen/sumfact/realization.py
+++ b/python/dune/codegen/sumfact/realization.py
@@ -160,8 +160,12 @@ def realize_sumfact_kernel_function(sf):
     for l, matrix in enumerate(matrix_sequence):
         # Compute the correct shapes of in- and output matrices of this matrix-matrix multiplication
         # and get inames that realize the product.
-        inp_shape = (matrix.cols,) + tuple(mat.cols for mat in matrix_sequence[l + 1:]) + tuple(mat.rows for mat in matrix_sequence[:l])
-        out_shape = (matrix.rows,) + tuple(mat.cols for mat in matrix_sequence[l + 1:]) + tuple(mat.rows for mat in matrix_sequence[:l])
+        inp_shape = (matrix.cols,) \
+                    + tuple(mat.cols for mat in matrix_sequence[l + 1:]) \
+                    + tuple(mat.rows for mat in matrix_sequence[:l])
+        out_shape = (matrix.rows,) \
+                    + tuple(mat.cols for mat in matrix_sequence[l + 1:]) \
+                    + tuple(mat.rows for mat in matrix_sequence[:l])
         out_inames = tuple(sumfact_iname(length, "out_inames_" + str(k)) for k, length in enumerate(out_shape))
         vec_iname = ()
         if matrix.vectorized:
@@ -173,11 +177,11 @@ def realize_sumfact_kernel_function(sf):
         # a code generation corner case producing way too complicated code. This
         # could be fixed upstream, but the loopy code realizing reductions is not
         # trivial and the priority is kind of low.
-        if matrix.cols != 1:
+        if matrix.cols == 1:
+            k_expr = 0
+        else:
             k = sumfact_iname(matrix.cols, "red")
             k_expr = prim.Variable(k)
-        else:
-            k_expr = 0
 
         # Setup the input of the sum factorization kernel. In the
         # first matrix multiplication this can be taken from
@@ -187,15 +191,7 @@ def realize_sumfact_kernel_function(sf):
         #   (vectorized + FastDGGridOperator)
         input_inames = (k_expr,) + tuple(prim.Variable(j) for j in out_inames[1:])
         if l == 0 and sf.stage == 1 and sf.interface.direct_is_possible:
-            # One permutation for cost reduction, see comment below
-            inp_shape = permute_backward(inp_shape, sf.cost_permutation)
-            input_inames = permute_backward(input_inames, sf.cost_permutation)
-
-            # And one more for permuted quadrature points, see comment below
-            inp_shape = permute_backward(inp_shape, sf.interface.quadrature_permutation)
-            input_inames = permute_backward(input_inames, sf.interface.quadrature_permutation)
-
-            input_summand = sf.interface.realize_direct(inp_shape, input_inames)
+            input_summand = sf.interface.realize_input(inp_shape, input_inames)
         else:
             # If we did permute the order of a matrices above we also
             # permuted the order of out_inames. Unfortunately the
diff --git a/python/dune/codegen/sumfact/symbolic.py b/python/dune/codegen/sumfact/symbolic.py
index 82f7e027..0cafec86 100644
--- a/python/dune/codegen/sumfact/symbolic.py
+++ b/python/dune/codegen/sumfact/symbolic.py
@@ -40,11 +40,11 @@ class SumfactKernelInterfaceBase(object):
 
     @property
     def quadrature_permutation(self):
-        return ()
+        raise NotImplementedError
 
     @property
     def cost_permutation(self):
-        return ()
+        raise NotImplementedError
 
     @property
     def combined_permutation(self):
@@ -63,12 +63,12 @@ class SumfactKernelInterfaceBase(object):
     def permute_forward_cost(self, shape, inames):
         shape = permute_forward(shape, self.cost_permutation)
         inames = permute_forward(inames, self.cost_permutation)
-        return shape_inames
+        return shape, inames
 
     def permute_forward_quadrature(self, shape, inames):
         shape = permute_forward(shape, self.quadrature_permutation)
         inames = permute_forward(inames, self.quadrature_permutation)
-        return shape_inames
+        return shape, inames
 
     @property
     def within_inames(self):
@@ -113,6 +113,20 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase):
             assert i.quadrature_permutation == self.interfaces[0].quadrature_permutation
         return self.interfaces[0].quadrature_permutation
 
+    @property
+    def cost_permutation(self):
+        # TODO: The cost_permutation of the underlying scalar SumfactKernels
+        # can differ from kernel to kernel, which is not handled here yet.
+
+        # Until it is, require that all interfaces agree on one permutation.
+        cost_permutation = self.interfaces[0].cost_permutation
+        for i in self.interfaces:
+            assert i.cost_permutation == cost_permutation
+
+        return cost_permutation
+
+        # raise RuntimeError("cost_permutation should not be called on VectorSumfactKernelInput")
+
     @property
     def stage(self):
         return 1
@@ -153,6 +167,17 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase):
             # need to load scalars into the SIMD vector.
             raise NotImplementedError("SIMD loads from scalars not implemented!")
 
+    def realize_input(self, shape, inames):
+        if self.direct_is_possible:
+            shape = permute_backward(shape, self.cost_permutation)
+            shape = permute_backward(shape, self.quadrature_permutation)
+            inames = permute_backward(inames, self.cost_permutation)
+            inames = permute_backward(inames, self.quadrature_permutation)
+
+            return self.realize_direct(shape, inames)
+        else:
+            raise NotImplementedError("TODO")
+
     @property
     def function_args(self):
         return sum((i.function_args for i in remove_duplicates(self.interfaces)), ())
-- 
GitLab