From f98117ab3d106add13773b6b51d474769eabc6b5 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Tue, 25 Oct 2016 11:02:25 +0200
Subject: [PATCH] Also do stage 3

---
 python/dune/perftool/sumfact/amatrix.py | 43 +++++++-----
 python/dune/perftool/sumfact/sumfact.py | 93 ++++++++++++++-----------
 2 files changed, 80 insertions(+), 56 deletions(-)

diff --git a/python/dune/perftool/sumfact/amatrix.py b/python/dune/perftool/sumfact/amatrix.py
index 98d3b55a..9d2e2498 100644
--- a/python/dune/perftool/sumfact/amatrix.py
+++ b/python/dune/perftool/sumfact/amatrix.py
@@ -32,17 +32,16 @@ import numpy
 
 
 class AMatrix(Record):
-    def __init__(self, a_matrix, m, n):
+    def __init__(self, a_matrix, rows, cols):
         Record.__init__(self,
                         a_matrix=a_matrix,
-                        m=m,
-                        n=n,
+                        rows=rows,
+                        cols=cols,
                         )
 
 
 class ColMajorAccess(FunctionIdentifier):
     def __init__(self, amatrix):
-        assert isinstance(amatrix, AMatrix)
         self.amatrix = amatrix
 
     def __getinitargs__(self):
@@ -50,7 +49,7 @@ class ColMajorAccess(FunctionIdentifier):
 
     @property
     def name(self):
-        return '{}.colmajoraccess'.format(self.amatrix.a_matrix)
+        return '{}.colmajoraccess'.format(self.amatrix)
 
 
 @function_mangler
@@ -192,17 +191,19 @@ def sort_quadrature_points_weights():
 
 
 @constructor_block("operator")
-def construct_theta(name):
+def construct_theta(name, transpose):
     # Make sure that the quadrature points are sorted
     sort_quadrature_points_weights()
 
-    m = name_number_of_quadrature_points_per_direction()
-    n = name_number_of_basis_functions_per_direction()
+    if transpose:
+        shape = (name_number_of_basis_functions_per_direction(), name_number_of_quadrature_points_per_direction())
+    else:
+        shape = (name_number_of_quadrature_points_per_direction(), name_number_of_basis_functions_per_direction())
     polynomials = name_polynomials()
     qp = name_oned_quadrature_points()
 
-    return ["for (std::size_t i=0; i<{}; i++){{".format(m),
-            "  for (std::size_t j=0; j<{}; j++){{".format(n),
+    return ["for (std::size_t i=0; i<{}; i++){{".format(shape[0]),
+            "  for (std::size_t j=0; j<{}; j++){{".format(shape[1]),
             "    {}.colmajoraccess(i,j) = {}.p(j,{}[i]);".format(name, polynomials, qp),
             "  }",
             "}"]
@@ -223,17 +224,25 @@ def type_theta():
 
 
 @class_member("operator")
-def define_theta(name):
+def define_theta(name, transpose):
     theta_type = type_theta()
-    number_qp = quadrature_points_per_direction()
-    number_basis = basis_functions_per_direction()
-    globalarg(name, shape=(number_basis, number_qp), dtype=numpy.float32, dim_tags="f,f")
-    initializer_list(name, [str(number_qp), str(number_basis)], classtag="operator")
-    construct_theta(name)
+    if transpose:
+        shape = (basis_functions_per_direction(), quadrature_points_per_direction())
+    else:
+        shape = (quadrature_points_per_direction(), basis_functions_per_direction())
+    globalarg(name, shape=shape, dtype=numpy.float32, dim_tags="f,f")
+    initializer_list(name, [str(axis) for axis in shape], classtag="operator")
+    construct_theta(name, transpose)
     return "{} {};".format(theta_type, name)
 
 
 def name_theta():
     name = "Theta"
-    define_theta(name)
+    define_theta(name, False)
+    return name
+
+
+def name_theta_transposed():
+    name = "ThetaT"
+    define_theta(name, True)
     return name
diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py
index 74bd6fb9..1753589c 100644
--- a/python/dune/perftool/sumfact/sumfact.py
+++ b/python/dune/perftool/sumfact/sumfact.py
@@ -10,6 +10,13 @@ from dune.perftool.loopy.buffer import (get_buffer_temporary,
                                         switch_base_storage,
                                         )
 from dune.perftool.pdelab.spaces import name_lfs
+from dune.perftool.sumfact.amatrix import (AMatrix,
+                                               quadrature_points_per_direction,
+                                               basis_functions_per_direction,
+                                               name_theta,
+                                               name_theta_transposed,
+                                               )
+from dune.perftool.loopy.stages import stage_insn
 from pymbolic.primitives import (Call,
                                  Product,
                                  Subscript,
@@ -34,47 +41,51 @@ def sumfact_iname(bound, _type):
     return name
 
 
+def setup_theta(element, container, restriction, component, a_matrices):
+    number_basis = product(mat.cols for mat in a_matrices)
+    shape = (number_basis,)
+    inp = get_buffer_temporary("buffer",
+                               shape=shape)
+    silenced_warning('read_no_write({})'.format(inp))
+
+    # Write initial coefficients into buffer
+    basisiname = sumfact_iname(number_basis, "basis")
+    lfs = name_lfs(element, restriction, component)
+    coeff = pymbolic_coefficient(container, lfs, basisiname)
+    assignee = Subscript(Variable(inp), (Variable(basisiname),))
+    return instruction(assignee=assignee,
+                       expression=coeff,
+                       depends_on=frozenset({stage_insn(0)}),
+                       )
+
+
 # TODO this code is WIP and mainly used for experiments.
 def start_sumfactorization(element, container, restriction, component):
-    from dune.perftool.sumfact.amatrix import (AMatrix,
-                                               quadrature_points_per_direction,
-                                               basis_functions_per_direction,
-                                               name_theta,
-                                               )
-
     theta = name_theta()
-    m = quadrature_points_per_direction()
-    n = basis_functions_per_direction()
-    a_matrix = AMatrix(theta, m, n)
+    rows = quadrature_points_per_direction()
+    cols = basis_functions_per_direction()
+    a_matrix = AMatrix(theta, rows, cols)
     a_matrices = (a_matrix, a_matrix)
 
+    # Do stage 1
     initialize_buffer("buffer",
-                      base_storage_size=product(max(mat.n, mat.m) for mat in a_matrices),
+                      base_storage_size=product(max(mat.rows, mat.cols) for mat in a_matrices),
                       num=2
                       )
 
+    insn_dep = setup_theta(element, container, restriction, component, a_matrices)
 
-    number_basis = product(mat.n for mat in a_matrices)
-    shape = (n,)
-    inp = get_buffer_temporary("buffer",
-                               shape=shape)
-    silenced_warning('read_no_write({})'.format(inp))
+    sum_factorization_kernel(a_matrices, "buffer", 0, frozenset({insn_dep}))
 
-    # Write initial coefficients into buffer
-    basisiname = sumfact_iname(number_basis, "basis")
-    lfs = name_lfs(element, restriction, component)
-    coeff = pymbolic_coefficient(container, lfs, basisiname)
-    assignee = Subscript(Variable(inp), (Variable(basisiname),))
-    from dune.perftool.loopy.stages import stage_insn
-    insn_dep = instruction(assignee = assignee,
-                           expression = coeff,
-                           depends_on = frozenset({stage_insn(0)}),
-                           )
+    # Do stage 3 (for f=u => mass matrix)
+    theta_transposed = name_theta_transposed()
+    a_matrix_transposed = AMatrix(theta_transposed, cols, rows)
+    a_matrices_transposed = (a_matrix_transposed, a_matrix_transposed)
 
-    return sum_factorization_kernel(a_matrices, inp, "buffer", insn_dep)
+    return sum_factorization_kernel(a_matrices_transposed, "buffer", 2)
 
 
-def sum_factorization_kernel(a_matrices, inp, buffer, insn_dep):
+def sum_factorization_kernel(a_matrices, buffer, stage, insn_dep=frozenset({})):
     """
     Calculate a sum factorization matrix product.
 
@@ -87,15 +98,18 @@ def sum_factorization_kernel(a_matrices, inp, buffer, insn_dep):
     a_matrices: An iterable of AMatrix instances
         The list of tensors to be applied to the input.
         Order of application is from 0 up.
-    inp: A temporary that contains the input matrix
     buffer: A string identifying the flip flop buffer in use
-        for intermediate results.
+        for intermediate results. The memory is expected to be
+        pre-initialized with the input.
+    insn_dep: an instruction ID that the first issued instruction
+        should depend upon. All following ones will depend on each
+        other.
     """
     for l, a_matrix in enumerate(a_matrices):
         # Get a temporary that interprets the base storage of the input
         # as a column-major matrix. In later iteration of the amatrix loop
         # this reinterprets the output of the previous iteration.
-        inp_shape = (a_matrix.n, product(mat.m for mat in a_matrices[:l]) * product(mat.n for mat in a_matrices[l + 1:]))
+        inp_shape = (a_matrix.cols, product(mat.rows for mat in a_matrices[:l]) * product(mat.cols for mat in a_matrices[l + 1:]))
         inp = get_buffer_temporary(buffer,
                                    shape=inp_shape,
                                    dim_tags="f,f")
@@ -107,7 +121,7 @@ def sum_factorization_kernel(a_matrices, inp, buffer, insn_dep):
 
         # Get a temporary that interprets the base storage of the output
         # as row-major matrix
-        out_shape = (a_matrix.m, product(mat.m for mat in a_matrices[:l]) * product(mat.n for mat in a_matrices[l + 1:]))
+        out_shape = (a_matrix.rows, product(mat.rows for mat in a_matrices[:l]) * product(mat.cols for mat in a_matrices[l + 1:]))
         out = get_buffer_temporary(buffer,
                                    shape=out_shape,
                                    dim_tags="c,c")
@@ -115,21 +129,22 @@ def sum_factorization_kernel(a_matrices, inp, buffer, insn_dep):
         # Get the inames needed for one matrix-matrix multiplication
         i = sumfact_iname(out_shape[0], "row")
         j = sumfact_iname(out_shape[1], "col")
-        k = sumfact_iname(a_matrix.n, "red")
+        k = sumfact_iname(a_matrix.cols, "red")
 
         # Construct the matrix-matrix-multiplication expression a_ik*in_kj
         from dune.perftool.sumfact.amatrix import ColMajorAccess
-        prod = Product((Call(ColMajorAccess(a_matrix), (Variable(i), Variable(k))),
+        prod = Product((Call(ColMajorAccess(a_matrix.a_matrix), (Variable(i), Variable(k))),
                         Subscript(Variable(inp), (Variable(k), Variable(j)))
                         ))
 
         # Issue the reduction instruction that implements the multiplication
         # at the same time store the instruction ID for the next instruction to depend on
-        insn_dep = instruction(assignee=Subscript(Variable(out), (Variable(i), Variable(j))),
-                               expression=Reduction("sum", k, prod),
-                               forced_iname_deps=frozenset({i, j}),
-                               forced_iname_deps_is_final=True,
-                               depends_on=frozenset({insn_dep}),
-                               )
+        insn_dep = frozenset({instruction(assignee=Subscript(Variable(out), (Variable(i), Variable(j))),
+                                          expression=Reduction("sum", k, prod),
+                                          forced_iname_deps=frozenset({i, j}),
+                                          forced_iname_deps_is_final=True,
+                                          depends_on=insn_dep.union(frozenset({stage_insn(stage)})),
+                                          )
+                              })
 
     return out
-- 
GitLab