From 327658d5b7196270bd875167f6c5369484ca8f6d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ren=C3=A9=20He=C3=9F?= <rene.hess@iwr.uni-heidelberg.de>
Date: Thu, 15 Nov 2018 14:13:53 +0100
Subject: [PATCH] Add permutation methods to Interface classes

Note: They are not yet used but in the long term the permutation should be
handled here since it is about input/output setup.
---
 python/dune/codegen/sumfact/accumulation.py | 18 ++++++++--
 python/dune/codegen/sumfact/basis.py        | 20 +++++++++--
 python/dune/codegen/sumfact/geometry.py     | 25 +++++++++++---
 python/dune/codegen/sumfact/permutation.py  |  6 +---
 python/dune/codegen/sumfact/symbolic.py     | 37 ++++++++++++++++++---
 5 files changed, 87 insertions(+), 19 deletions(-)

diff --git a/python/dune/codegen/sumfact/accumulation.py b/python/dune/codegen/sumfact/accumulation.py
index 44d6ffbd..e4512dc1 100644
--- a/python/dune/codegen/sumfact/accumulation.py
+++ b/python/dune/codegen/sumfact/accumulation.py
@@ -30,7 +30,10 @@ from dune.codegen.pdelab.restriction import restricted_name
 from dune.codegen.pdelab.signatures import assembler_routine_name
 from dune.codegen.pdelab.geometry import world_dimension
 from dune.codegen.pdelab.spaces import name_lfs
-from dune.codegen.sumfact.permutation import sumfact_quadrature_permutation_strategy
+from dune.codegen.sumfact.permutation import (permute_forward,
+                                              sumfact_cost_permutation_strategy,
+                                              sumfact_quadrature_permutation_strategy,
+                                              )
 from dune.codegen.sumfact.tabulation import (basis_functions_per_direction,
                                              construct_basis_matrix_sequence,
                                              )
@@ -88,6 +91,7 @@ def accum_iname(element, bound, i):
 
 class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord):
     def __init__(self,
+                 matrix_sequence,
                  accumvar=None,
                  restriction=None,
                  test_element=None,
@@ -105,6 +109,10 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord):
         dim = world_dimension()
         quadrature_permutation = sumfact_quadrature_permutation_strategy(dim, restriction[0])
 
+        # Calculate cost optimal permutation
+        matrix_sequence = permute_forward(matrix_sequence, quadrature_permutation)
+        cost_permutation = sumfact_cost_permutation_strategy(matrix_sequence, self.stage)
+
         # TODO: Isnt accumvar superfluous in the presence of all the other infos?
         ImmutableRecord.__init__(self,
                                  accumvar=accumvar,
@@ -114,6 +122,7 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord):
                                  trial_element=trial_element,
                                  trial_element_index=trial_element_index,
                                  _quadrature_permutation=quadrature_permutation,
+                                 _cost_permutation=cost_permutation,
                                  )
 
     def __repr__(self):
@@ -123,6 +132,10 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord):
     def quadrature_permutation(self):
         return self._quadrature_permutation
 
+    @property
+    def cost_permutation(self):
+        return self._cost_permutation
+
     @property
     def stage(self):
         return 3
@@ -457,7 +470,8 @@ def generate_accumulation_instruction(expr, visitor):
     if priority is None:
         priority = 3
 
-    output = AccumulationOutput(accumvar=accumvar,
+    output = AccumulationOutput(matrix_sequence,
+                                accumvar=accumvar,
                                 restriction=(test_info.restriction, trial_info.restriction),
                                 test_element=test_info.element,
                                 test_element_index=test_info.element_index,
diff --git a/python/dune/codegen/sumfact/basis.py b/python/dune/codegen/sumfact/basis.py
index eb5888b9..3ce2fb24 100644
--- a/python/dune/codegen/sumfact/basis.py
+++ b/python/dune/codegen/sumfact/basis.py
@@ -24,7 +24,10 @@ from dune.codegen.sumfact.tabulation import (basis_functions_per_direction,
                                              name_polynomials,
                                              polynomial_degree,
                                              )
-from dune.codegen.sumfact.permutation import sumfact_quadrature_permutation_strategy
+from dune.codegen.sumfact.permutation import (permute_forward,
+                                              sumfact_cost_permutation_strategy,
+                                              sumfact_quadrature_permutation_strategy,
+                                              )
 from dune.codegen.sumfact.quadrature import quadrature_inames
 from dune.codegen.sumfact.switch import (get_facedir,
                                          get_facemod,
@@ -53,6 +56,7 @@ import pymbolic.primitives as prim
 
 class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord):
     def __init__(self,
+                 matrix_sequence,
                  coeff_func=None,
                  element=None,
                  element_index=0,
@@ -68,12 +72,16 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord):
         dim = world_dimension()
         quadrature_permutation = sumfact_quadrature_permutation_strategy(dim, restriction)
 
+        matrix_sequence = permute_forward(matrix_sequence, quadrature_permutation)
+        cost_permutation = sumfact_cost_permutation_strategy(matrix_sequence, self.stage)
+
         ImmutableRecord.__init__(self,
                                  coeff_func=coeff_func,
                                  element=element,
                                  element_index=element_index,
                                  restriction=restriction,
                                  _quadrature_permutation=quadrature_permutation,
+                                 _cost_permutation=cost_permutation,
                                  )
 
     def __repr__(self):
@@ -86,6 +94,10 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord):
     def quadrature_permutation(self):
         return self._quadrature_permutation
 
+    @property
+    def cost_permutation(self):
+        return self._cost_permutation
+
     @property
     def stage(self):
         return 1
@@ -196,7 +208,8 @@ def pymbolic_coefficient_gradient(element, restriction, index, coeff_func, visit
                                                       basis_size=basis_size,
                                                       )
 
-    inp = LFSSumfactKernelInput(coeff_func=coeff_func,
+    inp = LFSSumfactKernelInput(matrix_sequence,
+                                coeff_func=coeff_func,
                                 element=element,
                                 element_index=index,
                                 restriction=restriction,
@@ -239,7 +252,8 @@ def pymbolic_coefficient(element, restriction, index, coeff_func, visitor):
                                                       facemod=get_facemod(restriction),
                                                       basis_size=basis_size)
 
-    inp = LFSSumfactKernelInput(coeff_func=coeff_func,
+    inp = LFSSumfactKernelInput(matrix_sequence,
+                                coeff_func=coeff_func,
                                 element=element,
                                 element_index=index,
                                 restriction=restriction,
diff --git a/python/dune/codegen/sumfact/geometry.py b/python/dune/codegen/sumfact/geometry.py
index f8ce62b3..ce2922a8 100644
--- a/python/dune/codegen/sumfact/geometry.py
+++ b/python/dune/codegen/sumfact/geometry.py
@@ -28,7 +28,10 @@ from dune.codegen.pdelab.localoperator import (name_ansatz_gfs_constructor_param
 from dune.codegen.pdelab.restriction import restricted_name
 from dune.codegen.sumfact.accumulation import basis_sf_kernels
 from dune.codegen.sumfact.basis import construct_basis_matrix_sequence
-from dune.codegen.sumfact.permutation import sumfact_quadrature_permutation_strategy
+from dune.codegen.sumfact.permutation import (permute_forward,
+                                              sumfact_cost_permutation_strategy,
+                                              sumfact_quadrature_permutation_strategy,
+                                              )
 from dune.codegen.sumfact.quadrature import (additional_inames,
                                              default_quadrature_inames)
 from dune.codegen.sumfact.realization import (name_buffer_storage,
@@ -57,7 +60,7 @@ def global_corner_iname(restriction):
 
 
 class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableRecord):
-    def __init__(self, direction, restriction):
+    def __init__(self, matrix_sequence, direction, restriction):
         """Base class for sum-factorized evaluation of geometry mappings
 
         At the moment we only do this for cells and not faces. For
@@ -78,7 +81,15 @@ class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableRecord):
         dim = world_dimension()
         quadrature_permutation = sumfact_quadrature_permutation_strategy(dim, restriction)
 
-        ImmutableRecord.__init__(self, direction=direction, restriction=restriction, _quadrature_permutation=quadrature_permutation)
+        matrix_sequence = permute_forward(matrix_sequence, quadrature_permutation)
+        cost_permutation = sumfact_cost_permutation_strategy(matrix_sequence, self.stage)
+
+        ImmutableRecord.__init__(self,
+                                 direction=direction,
+                                 restriction=restriction,
+                                 _quadrature_permutation=quadrature_permutation,
+                                 _cost_permutation=cost_permutation,
+                                 )
 
     def __repr__(self):
         return ImmutableRecord.__repr__(self)
@@ -90,6 +101,10 @@ class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableRecord):
     def quadrature_permutation(self):
         return self._quadrature_permutation
 
+    @property
+    def cost_permutation(self):
+        return self._cost_permutation
+
     @property
     def stage(self):
         return 1
@@ -160,7 +175,7 @@ def pymbolic_spatial_coordinate_multilinear(do_predicates, visitor):
     matrix_sequence = construct_basis_matrix_sequence(facedir=get_facedir(restriction),
                                                       facemod=get_facemod(restriction),
                                                       basis_size=(2,) * world_dimension())
-    inp = GeoCornersInput(visitor.indices[0], restriction)
+    inp = GeoCornersInput(matrix_sequence, visitor.indices[0], restriction)
     sf = SumfactKernel(matrix_sequence=matrix_sequence,
                        interface=inp,
                        )
@@ -537,7 +552,7 @@ def _name_jacobian(i, j, restriction, visitor):
                                                       basis_size=(2,) * world_dimension())
 
     # Sum factorization input for the i'th component of the geometry mapping
-    inp = GeoCornersInput(i, restriction)
+    inp = GeoCornersInput(matrix_sequence, i, restriction)
 
     sf = SumfactKernel(matrix_sequence=matrix_sequence,
                        interface=inp,
diff --git a/python/dune/codegen/sumfact/permutation.py b/python/dune/codegen/sumfact/permutation.py
index c2c71919..0c98b368 100644
--- a/python/dune/codegen/sumfact/permutation.py
+++ b/python/dune/codegen/sumfact/permutation.py
@@ -45,16 +45,12 @@ def flop_cost(matrix_sequence):
     return 2 * cost
 
 
-def sumfact_cost_permutation_strategy(sf):
+def sumfact_cost_permutation_strategy(matrix_sequence, stage):
     """Choose permutation of the matrix sequence based on computational cost
 
     Note: If there are multiple permutations with the same cost a
     heuristic is used to pick one.
     """
-    # Extract information from the SumfactKernel object
-    matrix_sequence = sf.matrix_sequence_quadrature_permuted
-    stage = sf.stage
-
     # Combine permutation and matrix_sequence
     perm = [i for i, _ in enumerate(matrix_sequence)]
     perm_matrix_sequence = zip(perm, matrix_sequence)
diff --git a/python/dune/codegen/sumfact/symbolic.py b/python/dune/codegen/sumfact/symbolic.py
index a0868a5c..82f7e027 100644
--- a/python/dune/codegen/sumfact/symbolic.py
+++ b/python/dune/codegen/sumfact/symbolic.py
@@ -7,6 +7,7 @@ from dune.codegen.generation import (get_counted_variable,
                                      )
 from dune.codegen.pdelab.geometry import local_dimension, world_dimension
 from dune.codegen.sumfact.permutation import (flop_cost,
+                                              permute_backward,
                                               permute_forward,
                                               sumfact_cost_permutation_strategy,
                                               sumfact_quadrature_permutation_strategy,
@@ -41,6 +42,34 @@ class SumfactKernelInterfaceBase(object):
     def quadrature_permutation(self):
         return ()
 
+    @property
+    def cost_permutation(self):
+        return ()
+
+    @property
+    def combined_permutation(self):
+        return permute_forward(self.quadrature_permutation, self.cost_permutation)
+
+    def permute_backward_cost(self, shape, inames):
+        shape = permute_backward(shape, self.cost_permutation)
+        inames = permute_backward(inames, self.cost_permutation)
+        return shape, inames
+
+    def permute_backward_quadrature(self, shape, inames):
+        shape = permute_backward(shape, self.quadrature_permutation)
+        inames = permute_backward(inames, self.quadrature_permutation)
+        return shape, inames
+
+    def permute_forward_cost(self, shape, inames):
+        shape = permute_forward(shape, self.cost_permutation)
+        inames = permute_forward(inames, self.cost_permutation)
+        return shape_inames
+
+    def permute_forward_quadrature(self, shape, inames):
+        shape = permute_forward(shape, self.quadrature_permutation)
+        inames = permute_forward(inames, self.quadrature_permutation)
+        return shape_inames
+
     @property
     def within_inames(self):
         return ()
@@ -474,13 +503,13 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
 
         Rule of thumb: small m's early and large n's late.
         """
-        perm = sumfact_cost_permutation_strategy(self)
+        perm = sumfact_cost_permutation_strategy(self.matrix_sequence_quadrature_permuted, self.stage)
         matrix_sequence_cost_permuted = permute_forward(self.matrix_sequence_quadrature_permuted, perm)
         return matrix_sequence_cost_permuted
 
     @property
     def cost_permutation(self):
-        return sumfact_cost_permutation_strategy(self)
+        return sumfact_cost_permutation_strategy(self.matrix_sequence_quadrature_permuted, self.stage)
 
     @property
     def quadrature_shape(self):
@@ -713,13 +742,13 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
 
     @property
     def matrix_sequence_cost_permuted(self):
-        perm = sumfact_cost_permutation_strategy(self)
+        perm = sumfact_cost_permutation_strategy(self.matrix_sequence_quadrature_permuted, self.stage)
         matrix_sequence_cost_permuted = permute_forward(self.matrix_sequence_quadrature_permuted, perm)
         return matrix_sequence_cost_permuted
 
     @property
     def cost_permutation(self):
-        return sumfact_cost_permutation_strategy(self)
+        return sumfact_cost_permutation_strategy(self.matrix_sequence_quadrature_permuted, self.stage)
 
     @property
     def stage(self):
-- 
GitLab