From d920c3e1651309703ecf46e3e2f119cfa841d81a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ren=C3=A9=20He=C3=9F?= <rene.hess@iwr.uni-heidelberg.de>
Date: Mon, 17 Dec 2018 11:36:37 +0100
Subject: [PATCH] [skip ci][WIP] Code cleanup

---
 python/dune/codegen/sumfact/autotune.py |  6 +--
 python/dune/codegen/sumfact/symbolic.py | 64 ++++++++++++++++++-------
 2 files changed, 49 insertions(+), 21 deletions(-)

diff --git a/python/dune/codegen/sumfact/autotune.py b/python/dune/codegen/sumfact/autotune.py
index 48b281d1..4b9a1d13 100644
--- a/python/dune/codegen/sumfact/autotune.py
+++ b/python/dune/codegen/sumfact/autotune.py
@@ -93,7 +93,7 @@ def generate_standalone_code(sf, filename):
         f.write("  using DF = {};\n".format(real))
 
         from dune.codegen.sumfact.tabulation import name_polynomials
-        degs = tuple(m.basis_size - 1 for m in sf.matrix_sequence)
+        degs = tuple(m.basis_size - 1 for m in sf.matrix_sequence_quadrature_permuted)
         for deg in set(degs):
             f.write("  Dune::QkStuff::EquidistantLagrangePolynomials<DF, RF, {}> {};\n".format(deg, name_polynomials(deg)))
 
@@ -105,8 +105,8 @@ def generate_standalone_code(sf, filename):
         constructor_knl = lp.get_one_scheduled_kernel(constructor_knl)
 
         # Allocate buffers
-        size = max(product(m.quadrature_size for m in sf.matrix_sequence) * sf.vector_width,
-                   product(m.basis_size for m in sf.matrix_sequence) * sf.vector_width)
+        size = max(product(m.quadrature_size for m in sf.matrix_sequence_quadrature_permuted) * sf.vector_width,
+                   product(m.basis_size for m in sf.matrix_sequence_quadrature_permuted) * sf.vector_width)
         size = int(size * (get_option("precision_bits") / 8))
         f.writelines(["  char buffer0[{}] __attribute__ ((aligned (32)));\n".format(size),
                       "  char buffer1[{}] __attribute__ ((aligned (32)));\n".format(size),
diff --git a/python/dune/codegen/sumfact/symbolic.py b/python/dune/codegen/sumfact/symbolic.py
index 8015da4c..45cd7548 100644
--- a/python/dune/codegen/sumfact/symbolic.py
+++ b/python/dune/codegen/sumfact/symbolic.py
@@ -303,7 +303,12 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase):
 
     @property
     def quadrature_permutation(self):
-        # TODO: For now we assure that all kerneles have the same quadrature_permutation
+        # TODO: This should be turned into an error, since the quadrature
+        # permutation could be different for different scalar kernels:
+        #
+        # raise RuntimeError('quadrature_permutation should not be called on VectorSumfactKernelOutput')
+        #
+        # Until then we assert that all interfaces agree and return the shared one.
         for i in self.interfaces:
             assert i.quadrature_permutation == self.interfaces[0].quadrature_permutation
         return self.interfaces[0].quadrature_permutation
@@ -328,13 +333,15 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase):
         return prim.Call(prim.Variable(hadd_function), (result,))
 
     def realize_input(self, inames, shape, vec_iname, vec_shape, buffer, ftags, l):
-        # TODO: Include permutations of scalar kernels as soon as they could be different
+        # The input for stage 3 is quadrature permuted. The inames and shape
+        # passed to this method are quadrature and cost permuted. This means we
+        # need to undo the cost permutation to get the right inames and shape
+        # for interpreting the input!
         shape = permute_backward(shape, self.cost_permutation)
         inames = permute_backward(inames, self.cost_permutation)
 
-        # Get a temporary that interprets the base storage of the input
-        # as a column-major matrix. In later iteration of the matrix loop
-        # this reinterprets the output of the previous iteration.
+        # Get a temporary that interprets the base storage of the input as a
+        # column-major matrix.
         inp = buffer.get_temporary("buff_step{}_in".format(l),
                                    shape=shape + vec_shape,
                                    dim_tags=ftags,
@@ -370,16 +377,26 @@ class VectorSumfactKernelOutput(SumfactKernelInterfaceBase):
         return deps
 
     def accumulate_output(self, sf, result, insn_dep):
-        # TODO: vector_cost_permutation not used!
-
         outputs = set(self.interfaces)
 
+        # Note: Using matrix_sequence_quadrature_permuted is ok in this place since:
+        #
+        # - If the grid is unstructured we assume that the polynomial degree
+        #   for each direction is the same.
+        #
+        # - If the grid is structured the quadrature permuted matrix sequence
+        #   is the same as the original one.  We still need to call this one
+        #   since VectorizedSumfactKernels do not have the matrix_sequence
+        #   attribute.
+        basis_size = tuple(mat.basis_size for mat in sf.matrix_sequence_quadrature_permuted)
+        if get_option('grid_unstructured'):
+            assert len(set(basis_size)) == 1
+
         trial_element, = set(o.trial_element for o in self.interfaces)
         trial_element_index = set(o.trial_element_index for o in self.interfaces).pop()
         from dune.codegen.sumfact.accumulation import accum_iname
         element = get_leaf(trial_element, trial_element_index) if trial_element is not None else None
-        inames = tuple(accum_iname(element, mat.rows, i)
-                       for i, mat in enumerate(sf.matrix_sequence_quadrature_permuted))
+        inames = tuple(accum_iname(element, size, i) for i, size in enumerate(basis_size))
         veciname = accum_iname(element, sf.vector_width // len(outputs), "vec")
         transform(lp.tag_inames, [(veciname, "vec")])
 
@@ -807,9 +824,10 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
         assert len(set(k.within_inames for k in kernels)) == 1
         assert len(set(k.predicates for k in kernels)) == 1
 
-        # Assert properties of the matrix sequence of the underlying kernels
         # For now we don't mix direct and non_direct input. Could be done in an upper/lower way.
         assert len(set(tuple(k.interface.direct_is_possible for k in kernels))) == 1
+
+        # Assert properties of the matrix sequence of the underlying kernels
         for i in range(kernels[0].length):
             assert len(set(tuple(k.matrix_sequence_quadrature_permuted[i].rows for k in kernels))) == 1
             assert len(set(tuple(k.matrix_sequence_quadrature_permuted[i].cols for k in kernels))) == 1
@@ -819,9 +837,9 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
         # Join the instruction dependencies of all subkernels
         insn_dep = insn_dep.union(k.insn_dep for k in kernels)
 
-        # Assert that quadrature permutation is the same for all kernels
+        # Assert that the cost_permutation is the same for all kernels
         for k in kernels:
-            assert k.interface.quadrature_permutation == kernels[0].interface.quadrature_permutation
+            assert k.interface.cost_permutation == kernels[0].interface.cost_permutation
 
         # We currently assume that all subkernels are consecutive, 0-based within the vector
         assert None not in kernels
@@ -881,7 +899,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
 
     @property
     def matrix_sequence_quadrature_permuted(self):
-        # TODO: This should be turned into a RuntimeError
+        # Construct quadrature permuted matrix sequence from scalar case
         return tuple(BasisTabulationMatrixArray(tuple(k.matrix_sequence_quadrature_permuted[i] for k in self.kernels),
                                                 width=self.vector_width,
                                                 )
@@ -889,13 +907,20 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
 
     @property
     def matrix_sequence_cost_permuted(self):
-        perm = sumfact_cost_permutation_strategy(self.matrix_sequence_quadrature_permuted, self.stage)
-        matrix_sequence_cost_permuted = permute_forward(self.matrix_sequence_quadrature_permuted, perm)
-        return matrix_sequence_cost_permuted
+        # Construct cost permuted matrix sequence from scalar case
+        matrix_sequence = tuple(BasisTabulationMatrixArray(tuple(k.matrix_sequence_cost_permuted[i] for k in self.kernels),
+                                                           width=self.vector_width,)
+                                for i in range(self.length))
+
+        # This should already be cost optimal, i.e. yield the identity permutation
+        perm = sumfact_cost_permutation_strategy(matrix_sequence, self.stage)
+        assert perm == tuple(i for i in range(len(perm)))
+
+        return matrix_sequence
 
     @property
     def cost_permutation(self):
-        return sumfact_cost_permutation_strategy(self.matrix_sequence_quadrature_permuted, self.stage)
+        return self.kernels[0].cost_permutation
 
     @property
     def stage(self):
@@ -907,7 +932,10 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
 
     @property
     def quadrature_permutation(self):
-        return self.kernels[0].interface.quadrature_permutation
+        # The quadrature_permutations of the underlying scalar kernels can be
+        # different from kernel to kernel. So there is no well defined
+        # quadrature_permutation on the VectorizedSumfactKernel.
+        raise RuntimeError("quadrature_permutation should not be used on VectorizedSumfactKernel.")
 
     @property
     def within_inames(self):
-- 
GitLab