From 73abd9a0726a28f6195192b081840607e677e66c Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Tue, 20 Dec 2016 13:58:17 +0100
Subject: [PATCH] Finalize broadcasting of input coefficients in vectorized
 case

---
 python/dune/perftool/sumfact/basis.py   |  6 +-
 python/dune/perftool/sumfact/sumfact.py | 86 ++++++++++++++-----------
 2 files changed, 49 insertions(+), 43 deletions(-)

diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py
index c9ba31b6..d3452f31 100644
--- a/python/dune/perftool/sumfact/basis.py
+++ b/python/dune/perftool/sumfact/basis.py
@@ -93,8 +93,7 @@ def pymbolic_trialfunction_gradient(element, restriction, component, visitor):
         if insn_dep is None:
             insn_dep = frozenset({Writes(inp)})
 
-        # TODO: fastdg and vectorization
-        if get_option('fastdg') and index is None:
+        if get_option('fastdg'):
             # Name of direct input, shape and globalarg is set in sum_factorization_kernel
             direct_input = name_coefficientcontainer(restriction)
         else:
@@ -163,10 +162,9 @@ def pymbolic_trialfunction(element, restriction, component, visitor):
                                       )
 
     # TODO: fastdg and vectorization
-    if get_option('fastdg') and index is not None:
+    if get_option('fastdg'):
         # Name of direct input, shape and globalarg is set in sum_factorization_kernel
         direct_input = name_coefficientcontainer(restriction)
-        setup_theta(inp, element, restriction, component, index)
     else:
         direct_input = None
         # Setup the input!
diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py
index ef9cf9a5..e135fb75 100644
--- a/python/dune/perftool/sumfact/sumfact.py
+++ b/python/dune/perftool/sumfact/sumfact.py
@@ -357,62 +357,70 @@ def sum_factorization_kernel(a_matrices,
                                   )})
 
     for l, a_matrix in enumerate(a_matrices):
-        # Get a temporary that interprets the base storage of the input
-        # as a column-major matrix. In later iteration of the amatrix loop
-        # this reinterprets the output of the previous iteration.
+        # Compute the correct shapes of in- and output matrices of this matrix-matrix multiplication
+        # and get inames that realize the product.
         inp_shape = (a_matrix.cols, product(mat.rows for mat in a_matrices[:l]) * product(mat.cols for mat in a_matrices[l + 1:]))
+        out_shape = (a_matrix.rows, product(mat.rows for mat in a_matrices[:l]) * product(mat.cols for mat in a_matrices[l + 1:]))
+        i = sumfact_iname(out_shape[0], "row")
+        j = sumfact_iname(out_shape[1], "col")
+        vec_iname = ()
+        if a_matrix.vectorized:
+            iname = sumfact_iname(4, "vec")
+            vec_iname = (prim.Variable(iname),)
+            transform(lp.tag_inames, [(iname, "vec")])
 
-        # TODO: fastdg and vectorization
-        if l==0 and direct_input is not None and not a_matrix.vectorized:
+        # A trivial reduction is implemented as a product, otherwise we run into
+        # a code generation corner case producing way too complicated code. This
+        # could be fixed upstream, but the loopy code realizing reductions is not
+        # trivial and the priority is kind of low.
+        if a_matrix.cols != 1:
+            k = sumfact_iname(a_matrix.cols, "red")
+            k_expr = prim.Variable(k)
+        else:
+            k_expr = 0
+
+        # Setup the input of the sum factorization kernel. In the first matrix multiplication
+        # this can be taken from
+        # * an input temporary (default)
+        # * a global data structure (if FastDGGridOperator is in use)
+        # * a value from a global data structure, broadcast to a vector type (vectorized + FastDGGridOperator)
+        if l==0 and direct_input is not None:
             globalarg(direct_input, dtype=np.float64, shape=inp_shape)
-            inp = direct_input
+            if a_matrix.vectorized:
+                input_summand = prim.Call(prim.Variable("Vec4d"), (prim.Subscript(prim.Variable(direct_input),
+                                                                                  (k_expr, prim.Variable(j))),))
+            else:
+                input_summand = prim.Subscript(prim.Variable(direct_input),
+                                               (k_expr, prim.Variable(j)) + vec_iname)
         else:
+            # Get a temporary that interprets the base storage of the input
+            # as a column-major matrix. In later iteration of the amatrix loop
+            # this reinterprets the output of the previous iteration.
             inp = get_buffer_temporary(buf,
                                        shape=inp_shape + vec_shape,
                                        dim_tags=ftags)
 
-        # The input temporary will only be read from, so we need to silence the loopy warning
-        silenced_warning('read_no_write({})'.format(inp))
+            # The input temporary will only be read from, so we need to silence the loopy warning
+            silenced_warning('read_no_write({})'.format(inp))
+
+            input_summand = prim.Subscript(prim.Variable(inp),
+                                           (k_expr, prim.Variable(j)) + vec_iname)
 
         switch_base_storage(buf)
 
         # Get a temporary that interprets the base storage of the output
         # as row-major matrix
-        out_shape = (a_matrix.rows, product(mat.rows for mat in a_matrices[:l]) * product(mat.cols for mat in a_matrices[l + 1:]))
-
         out = get_buffer_temporary(buf,
                                    shape=out_shape + vec_shape,
                                    dim_tags=ctags)
 
-        # Get the inames needed for one matrix-matrix multiplication
-        i = sumfact_iname(out_shape[0], "row")
-        j = sumfact_iname(out_shape[1], "col")
-
-        # Maybe introduce a vectorization iname for this matrix-matrix multiplication
-        vec_iname = ()
-        if a_matrix.vectorized:
-            iname = sumfact_iname(4, "vec")
-            vec_iname = (prim.Variable(iname),)
-            transform(lp.tag_inames, [(iname, "vec")])
-
-        if a_matrix.cols == 1:
-            # A trivial reduction is implemented as a product, otherwise we run into
-            # a code generation corner case producing way too complicated code. This
-            # could be fixed upstream, but the loopy code realizing reductions is not
-            # trivial and the priority is kind of low.
-            matprod = Product((prim.Subscript(prim.Variable(a_matrix.name), (prim.Variable(i), 0) + vec_iname),
-                               prim.Subscript(prim.Variable(inp), (0, prim.Variable(j)) + vec_iname)
-                               ))
-        else:
-            k = sumfact_iname(a_matrix.cols, "red")
-
-            # Construct the matrix-matrix-multiplication expression a_ik*in_kj
-            prod = prim.Product((prim.Subscript(prim.Variable(a_matrix.name),
-                                                (prim.Variable(i), prim.Variable(k)) + vec_iname),
-                                 prim.Subscript(prim.Variable(inp),
-                                                (prim.Variable(k), prim.Variable(j)) + vec_iname)
-                                ))
-            matprod = lp.Reduction("sum", k, prod)
+        # Write the matrix-matrix multiplication expression
+        matprod = Product((prim.Subscript(prim.Variable(a_matrix.name), (prim.Variable(i), k_expr) + vec_iname),
+                           input_summand
+                           ))
+        if a_matrix.cols != 1:
+            # ... which is wrapped in a sum reduction over k unless the matrix has exactly one column
+            matprod = lp.Reduction("sum", k, matprod)
 
         # Issue the reduction instruction that implements the multiplication
         # at the same time store the instruction ID for the next instruction to depend on
-- 
GitLab