From a3ce64776cd153cbb023d92ee47da96f6cdbc10a Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Fri, 7 Apr 2017 11:20:22 +0200
Subject: [PATCH] Working vertical mass matrix example

---
 python/dune/perftool/sumfact/realization.py   |  6 +++---
 python/dune/perftool/sumfact/symbolic.py      | 13 ++++++++++---
 python/dune/perftool/sumfact/vectorization.py |  1 +
 test/sumfact/mass/CMakeLists.txt              |  8 ++++----
 test/sumfact/mass/mass_3d.mini                |  2 +-
 5 files changed, 19 insertions(+), 11 deletions(-)

diff --git a/python/dune/perftool/sumfact/realization.py b/python/dune/perftool/sumfact/realization.py
index df28b37c..e3ad6541 100644
--- a/python/dune/perftool/sumfact/realization.py
+++ b/python/dune/perftool/sumfact/realization.py
@@ -96,7 +96,7 @@ def _realize_sum_factorization_kernel(sf):
     if sf.vectorized:
         ftags = ftags + ",vec"
         ctags = ctags + ",vec"
-        vec_shape = (sf.horizontal_width,)
+        vec_shape = (sf.vector_width,)
 
     # Measure times and count operations in c++ code
     if get_option("instrumentation_level") >= 4:
@@ -133,7 +133,7 @@ def _realize_sum_factorization_kernel(sf):
         out_inames = tuple(sumfact_iname(length, "out_inames_" + str(k)) for k, length in enumerate(out_shape))
         vec_iname = ()
         if matrix.vectorized:
-            iname = sumfact_iname(sf.horizontal_width, "vec")
+            iname = sumfact_iname(sf.vector_width, "vec")
             vec_iname = (prim.Variable(iname),)
             transform(lp.tag_inames, [(iname, "vec")])
 
@@ -160,7 +160,7 @@ def _realize_sum_factorization_kernel(sf):
 
             globalarg(direct_input, dtype=np.float64, shape=inp_shape, dim_tags=novec_ftags)
             if matrix.vectorized:
-                input_summand = prim.Call(prim.Variable(get_vcl_typename(np.float64, vector_width=sf.horizontal_width)),
+                input_summand = prim.Call(prim.Variable(get_vcl_typename(np.float64, vector_width=sf.vector_width)),
                                           (prim.Subscript(prim.Variable(direct_input),
                                                           input_inames),))
             else:
diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py
index fb7e19df..44bdbc5f 100644
--- a/python/dune/perftool/sumfact/symbolic.py
+++ b/python/dune/perftool/sumfact/symbolic.py
@@ -356,7 +356,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
     @property
     def matrix_sequence(self):
         return tuple(BasisTabulationMatrixArray(tuple(k.matrix_sequence[i] for k in self.kernels),
-                                                width=self.horizontal_width,
+                                                width=self.vector_width,
                                                 )
                      for i in range(self.length))
 
@@ -405,6 +405,9 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
         indices = set(range(self.horizontal_width)) - set(range(len(self.kernels)))
         return tuple(self.kernels[0].quadrature_index(None) + (i,) for i in indices)
 
+    @property
+    def vector_width(self):
+        return self.horizontal_width * self.vertical_width
     #
     # Define the same properties the normal SumfactKernel defines
     #
@@ -433,7 +436,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
 
     @property
     def flat_input_shape(self):
-        return (product(mat.cols for mat in self.matrix_sequence), self.horizontal_width)
+        return (product(mat.basis_size for mat in self.matrix_sequence), self.horizontal_width)
 
     @property
     def quadrature_shape(self):
@@ -444,7 +447,11 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
             assert isinstance(direct_index, tuple)
             return self.kernels[0].quadrature_index(sf) + direct_index
         else:
-            return self.kernels[0].quadrature_index(sf) + (self.kernels.index(sf),)
+            try:
+                horizontal_index = self.kernels.index(sf)
+            except ValueError:
+                horizontal_index = 0
+            return self.kernels[0].quadrature_index(sf) + (horizontal_index,)
 
     @property
     def quadrature_dimtags(self):
diff --git a/python/dune/perftool/sumfact/vectorization.py b/python/dune/perftool/sumfact/vectorization.py
index 3d984a76..1193b14a 100644
--- a/python/dune/perftool/sumfact/vectorization.py
+++ b/python/dune/perftool/sumfact/vectorization.py
@@ -70,6 +70,7 @@ def vertical_vectorization_strategy(sumfact, depth):
         vsf = VectorizedSumfactKernel(kernels=tuple(kernels),
                                       buffer=buffer,
                                       input=input,
+                                      vertical_width=depth,
                                       )
         return _cache_vectorization_info(sumfact, vsf)
     else:
diff --git a/test/sumfact/mass/CMakeLists.txt b/test/sumfact/mass/CMakeLists.txt
index b2ec50d0..a1313988 100644
--- a/test/sumfact/mass/CMakeLists.txt
+++ b/test/sumfact/mass/CMakeLists.txt
@@ -9,7 +9,7 @@ dune_add_formcompiler_system_test(UFLFILE mass_3d.ufl
                                   INIFILE mass_3d.mini
                                   )
 
-#dune_add_formcompiler_system_test(UFLFILE mass_3d.ufl
-#                                  BASENAME sumfact_mass_sliced
-#                                  INIFILE sliced.mini
-#                                  )
+dune_add_formcompiler_system_test(UFLFILE mass_3d.ufl
+                                  BASENAME sumfact_mass_sliced
+                                  INIFILE sliced.mini
+                                  )
diff --git a/test/sumfact/mass/mass_3d.mini b/test/sumfact/mass/mass_3d.mini
index 1244acec..77626456 100644
--- a/test/sumfact/mass/mass_3d.mini
+++ b/test/sumfact/mass/mass_3d.mini
@@ -19,4 +19,4 @@ vectorize_quad = 1, 0 | expand vec
 sumfact = 1
 
 [formcompiler.ufl_variants]
-degree = 1
+degree = 3
-- 
GitLab