From dd0363d9780f7eb3d7c917c27ec358cd8adfe660 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ren=C3=A9=20He=C3=9F?= <rene.hess@iwr.uni-heidelberg.de>
Date: Mon, 22 Oct 2018 16:27:53 +0200
Subject: [PATCH] Move quadrature_permutation to interface
 SumfactKernelInterfaceBase

---
 python/dune/codegen/sumfact/accumulation.py | 15 +++++++
 python/dune/codegen/sumfact/basis.py        | 15 +++++++
 python/dune/codegen/sumfact/geometry.py     | 16 ++++++-
 python/dune/codegen/sumfact/realization.py  | 14 +++---
 python/dune/codegen/sumfact/symbolic.py     | 48 +++++++++++----------
 5 files changed, 77 insertions(+), 31 deletions(-)

diff --git a/python/dune/codegen/sumfact/accumulation.py b/python/dune/codegen/sumfact/accumulation.py
index 6af42b2f..7059b3a5 100644
--- a/python/dune/codegen/sumfact/accumulation.py
+++ b/python/dune/codegen/sumfact/accumulation.py
@@ -94,6 +94,16 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord):
                  trial_element=None,
                  trial_element_index=None,
                  ):
+
+        # Note: The function sumfact_quadrature_permutation_strategy does not
+        # work anymore after the visiting process since get_facedir and
+        # get_facemod are not well defined. But we need the
+        # quadrature_permutation to generate the name of the sumfact
+        # kernel. This means we need to store the value here instead of
+        # recalculating it in the property.
+        dim = world_dimension()
+        quadrature_permutation = sumfact_quadrature_permutation_strategy(dim, restriction[0])
+
         # TODO: Isnt accumvar superfluous in the presence of all the other infos?
         ImmutableRecord.__init__(self,
                                  accumvar=accumvar,
@@ -102,11 +112,16 @@ class AccumulationOutput(SumfactKernelInterfaceBase, ImmutableRecord):
                                  test_element_index=test_element_index,
                                  trial_element=trial_element,
                                  trial_element_index=trial_element_index,
+                                 _quadrature_permutation=quadrature_permutation,
                                  )
 
     def __repr__(self):
         return ImmutableRecord.__repr__(self)
 
+    @property
+    def quadrature_permutation(self):
+        return self._quadrature_permutation
+
     @property
     def stage(self):
         return 3
diff --git a/python/dune/codegen/sumfact/basis.py b/python/dune/codegen/sumfact/basis.py
index 0a1ce5b8..a7d2e01e 100644
--- a/python/dune/codegen/sumfact/basis.py
+++ b/python/dune/codegen/sumfact/basis.py
@@ -57,11 +57,22 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord):
                  element_index=0,
                  restriction=0,
                  ):
+
+        # Note: The function sumfact_quadrature_permutation_strategy does not
+        # work anymore after the visiting process since get_facedir and
+        # get_facemod are not well defined. But we need the
+        # quadrature_permutation to generate the name of the sumfact
+        # kernel. This means we need to store the value here instead of
+        # recalculating it in the property.
+        dim = world_dimension()
+        quadrature_permutation = sumfact_quadrature_permutation_strategy(dim, restriction)
+
         ImmutableRecord.__init__(self,
                                  coeff_func=coeff_func,
                                  element=element,
                                  element_index=element_index,
                                  restriction=restriction,
+                                 _quadrature_permutation=quadrature_permutation,
                                  )
 
     def __repr__(self):
@@ -70,6 +81,10 @@ class LFSSumfactKernelInput(SumfactKernelInterfaceBase, ImmutableRecord):
     def __str__(self):
         return repr(self)
 
+    @property
+    def quadrature_permutation(self):
+        return self._quadrature_permutation
+
     @property
     def stage(self):
         return 1
diff --git a/python/dune/codegen/sumfact/geometry.py b/python/dune/codegen/sumfact/geometry.py
index 3d7df550..17aa5963 100644
--- a/python/dune/codegen/sumfact/geometry.py
+++ b/python/dune/codegen/sumfact/geometry.py
@@ -67,7 +67,17 @@ class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableRecord):
         argument 'direction' specifies the component (x-component: 0,
         y-component: 1, z-component: 2).
         """
-        ImmutableRecord.__init__(self, direction=direction, restriction=restriction)
+
+        # Note: The function sumfact_quadrature_permutation_strategy does not
+        # work anymore after the visiting process since get_facedir and
+        # get_facemod are not well defined. But we need the
+        # quadrature_permutation to generate the name of the sumfact
+        # kernel. This means we need to store the value here instead of
+        # recalculating it in the property.
+        dim = world_dimension()
+        quadrature_permutation = sumfact_quadrature_permutation_strategy(dim, restriction)
+
+        ImmutableRecord.__init__(self, direction=direction, restriction=restriction, _quadrature_permutation=quadrature_permutation)
 
     def __repr__(self):
         return ImmutableRecord.__repr__(self)
@@ -75,6 +85,10 @@ class GeoCornersInput(SumfactKernelInterfaceBase, ImmutableRecord):
     def __str__(self):
         return repr(self)
 
+    @property
+    def quadrature_permutation(self):
+        return self._quadrature_permutation
+
     @property
     def stage(self):
         return 1
diff --git a/python/dune/codegen/sumfact/realization.py b/python/dune/codegen/sumfact/realization.py
index ce4393ef..1d951eca 100644
--- a/python/dune/codegen/sumfact/realization.py
+++ b/python/dune/codegen/sumfact/realization.py
@@ -192,8 +192,8 @@ def realize_sumfact_kernel_function(sf):
             input_inames = permute_backward(input_inames, sf.cost_permutation)
 
             # And one more for permuted quadrature points, see comment below
-            inp_shape = permute_backward(inp_shape, sf.quadrature_permutation)
-            input_inames = permute_backward(input_inames, sf.quadrature_permutation)
+            inp_shape = permute_backward(inp_shape, sf.interface.quadrature_permutation)
+            input_inames = permute_backward(input_inames, sf.interface.quadrature_permutation)
 
             input_summand = sf.interface.realize_direct(inp_shape, input_inames)
         else:
@@ -207,10 +207,10 @@ def realize_sumfact_kernel_function(sf):
                 if sf.stage == 1:
                     # In the unstructured case the sf.matrix_sequence_quadrature_permuted could
                     # already be permuted according to
-                    # sf.quadrature_permutation. We also need to reverse this
+                    # sf.interface.quadrature_permutation. We also need to reverse this
                     # permutation to get the input from 0 to d-1.
-                    inp_shape = permute_backward(inp_shape, sf.quadrature_permutation)
-                    input_inames = permute_backward(input_inames, sf.quadrature_permutation)
+                    inp_shape = permute_backward(inp_shape, sf.interface.quadrature_permutation)
+                    input_inames = permute_backward(input_inames, sf.interface.quadrature_permutation)
 
             # Get a temporary that interprets the base storage of the input
             # as a column-major matrix. In later iteration of the matrix loop
@@ -242,7 +242,7 @@ def realize_sumfact_kernel_function(sf):
         if l == len(matrix_sequence) - 1:
             output_shape = permute_backward(output_shape, sf.cost_permutation)
             if sf.stage == 3:
-                output_shape = permute_backward(output_shape, sf.quadrature_permutation)
+                output_shape = permute_backward(output_shape, sf.interface.quadrature_permutation)
 
         out = buffer.get_temporary("buff_step{}_out".format(l),
                                    shape=output_shape + vec_shape,
@@ -263,7 +263,7 @@ def realize_sumfact_kernel_function(sf):
         if l == len(matrix_sequence) - 1:
             output_inames = permute_backward(output_inames, sf.cost_permutation)
             if sf.stage == 3:
-                output_inames = permute_backward(output_inames, sf.quadrature_permutation)
+                output_inames = permute_backward(output_inames, sf.interface.quadrature_permutation)
 
         # Collect the key word arguments for the loopy instruction
         insn_args = {"depends_on": insn_dep}
diff --git a/python/dune/codegen/sumfact/symbolic.py b/python/dune/codegen/sumfact/symbolic.py
index 91e48eb2..c23d77b6 100644
--- a/python/dune/codegen/sumfact/symbolic.py
+++ b/python/dune/codegen/sumfact/symbolic.py
@@ -37,6 +37,10 @@ class SumfactKernelInterfaceBase(object):
     def realize_direct(self, *a, **kw):
         raise NotImplementedError
 
+    @property
+    def quadrature_permutation(self):
+        return ()
+
     @property
     def within_inames(self):
         return ()
@@ -73,6 +77,13 @@ class VectorSumfactKernelInput(SumfactKernelInterfaceBase):
     def __repr__(self):
         return "_".join(repr(i) for i in self.interfaces)
 
+    @property
+    def quadrature_permutation(self):
+        # TODO: For now we assure that all kerneles have the same quadrature_permutation
+        for i in self.interfaces:
+            assert i.quadrature_permutation == self.interfaces[0].quadrature_permutation
+        return self.interfaces[0].quadrature_permutation
+
     @property
     def stage(self):
         return 1
@@ -140,6 +151,13 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase):
     def __repr__(self):
         return "_".join(repr(o) for o in self.interfaces)
 
+    @property
+    def quadrature_permutation(self):
+        # TODO: For now we assure that all kerneles have the same quadrature_permutation
+        for i in self.interfaces:
+            assert i.quadrature_permutation == self.interfaces[0].quadrature_permutation
+        return self.interfaces[0].quadrature_permutation
+
     @property
     def stage(self):
         return 3
@@ -237,7 +255,6 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
                  insn_dep=frozenset(),
                  interface=SumfactKernelInterfaceBase(),
                  predicates=frozenset(),
-                 quadrature_permutation=None,
                  ):
         """Create a sum factorization kernel
 
@@ -311,27 +328,12 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
         for a in SumfactKernel.init_arg_names:
             defaultdict[a] = eval(a)
 
-        dim = len(matrix_sequence)
-
         # Not sure if this whole permuting would make sense if we would do sum
         # factorized evaluation of intersections where len(matrix_sequence)
         # would not be equal to world dim.
+        dim = len(matrix_sequence)
         assert dim == world_dimension()
 
-        # Get restriction for this sum factorization kernel. Note: For
-        # accumulation output we have a restriction for the test (index 0) and
-        # ansatz (index 1) space. We need the restriction corresponding to the
-        # test space since we are in stage 3
-        restriction = interface.restriction
-        if isinstance(restriction, tuple):
-            assert interface.stage is 3
-            assert len(restriction) is 2
-            restriction = restriction[0]
-
-        # Store correct quadrature_permutation
-        quadrature_permuation = sumfact_quadrature_permutation_strategy(dim, restriction)
-        defaultdict['quadrature_permutation'] = quadrature_permuation
-
         # Call the base class constructors
         ImmutableRecord.__init__(self, **defaultdict)
         prim.Variable.__init__(self, "SUMFACT")
@@ -371,8 +373,8 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
         # different permuation of quadrature points on self and neighbor. Mangle
         # the permutation of the quadrature points into the name to generate
         # sperate functions.
-        if self.quadrature_permutation != tuple(range(len(self.matrix_sequence))):
-            name_quad_perm = "_qpperm_{}".format("".join(str(a) for a in self.quadrature_permutation))
+        if self.interface.quadrature_permutation != tuple(range(len(self.matrix_sequence))):
+            name_quad_perm = "_qpperm_{}".format("".join(str(a) for a in self.interface.quadrature_permutation))
             name = name + name_quad_perm
 
         return name
@@ -383,7 +385,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
         # TODO: For now we do not vectorize SumfactKernels with different
         # quadrature_permutation. This should be handled like upper/lower
         # vectorization
-        return self.quadrature_permutation + tuple(m.basis_size for m in self.matrix_sequence_quadrature_permuted) + (self.stage, self.buffer, self.interface.within_inames)
+        return tuple(m.basis_size for m in self.matrix_sequence_quadrature_permuted) + (self.stage, self.buffer, self.interface.within_inames)
 
     @property
     def cache_key(self):
@@ -456,7 +458,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
         to ensure that quadrature points are visited in the same order on self
         and neighbor.
         """
-        perm = self.quadrature_permutation
+        perm = self.interface.quadrature_permutation
         matrix_sequence_quadrature_permuted = permute_forward(self.matrix_sequence, perm)
         return matrix_sequence_quadrature_permuted
 
@@ -640,7 +642,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
 
         # Assert that quadrature permutation is the same for all kernels
         for k in kernels:
-            assert k.quadrature_permutation == kernels[0].quadrature_permutation
+            assert k.interface.quadrature_permutation == kernels[0].interface.quadrature_permutation
 
         # We currently assume that all subkernels are consecutive, 0-based within the vector
         assert None not in kernels
@@ -722,7 +724,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
 
     @property
     def quadrature_permutation(self):
-        return self.kernels[0].quadrature_permutation
+        return self.kernels[0].interface.quadrature_permutation
 
     @property
     def within_inames(self):
-- 
GitLab