From b58d4fd03a81aa47cde0b7f73236871c12c35a47 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Mon, 28 Nov 2016 17:49:47 +0100
Subject: [PATCH] Simplify theta matrix generation

---
 python/dune/perftool/sumfact/amatrix.py | 44 +++++++------------------
 python/dune/perftool/sumfact/basis.py   | 12 +++----
 python/dune/perftool/sumfact/sumfact.py | 10 ++----
 3 files changed, 19 insertions(+), 47 deletions(-)

diff --git a/python/dune/perftool/sumfact/amatrix.py b/python/dune/perftool/sumfact/amatrix.py
index 1c786393..41fe8ed7 100644
--- a/python/dune/perftool/sumfact/amatrix.py
+++ b/python/dune/perftool/sumfact/amatrix.py
@@ -35,15 +35,16 @@ import numpy
 
 
 class AMatrix(Record):
-    def __init__(self, a_matrix, rows, cols):
+    def __init__(self, rows, cols, transpose=False, derivative=False):
         Record.__init__(self,
-                        a_matrix=a_matrix,
                         rows=rows,
                         cols=cols,
+                        transpose=False,
+                        derivative=False,
                         )
 
     def __hash__(self):
-        return hash((self.a_matrix, self.rows, self.cols))
+        return hash((self.transpose, self.derivative, self.rows, self.cols))
 
 
 def quadrature_points_per_direction():
@@ -199,33 +200,12 @@ def define_theta(name, shape, transpose, derivative):
                 )
 
 
-def name_theta():
-    name = "Theta"
-    shape = (quadrature_points_per_direction(), basis_functions_per_direction())
-    globalarg(name, shape=shape, dtype=numpy.float64, dim_tags="f,f")
-    define_theta(name, shape, False, False)
-    return name
-
-
-def name_theta_transposed():
-    name = "ThetaT"
-    shape = (basis_functions_per_direction(), quadrature_points_per_direction())
-    globalarg(name, shape=shape, dtype=numpy.float64, dim_tags="f,f")
-    define_theta(name, shape, True, False)
-    return name
-
-
-def name_dtheta():
-    name = "dTheta"
-    shape = (quadrature_points_per_direction(), basis_functions_per_direction())
-    globalarg(name, shape=shape, dtype=numpy.float64, dim_tags="f,f")
-    define_theta(name, shape, False, True)
-    return name
-
-
-def name_dtheta_transposed():
-    name = "dThetaT"
-    shape = (basis_functions_per_direction(), quadrature_points_per_direction())
+def name_theta(transpose=False, derivative=False):
+    name = "{}Theta{}".format("d" if derivative else "", "T" if transpose else "")
+    if transpose:
+        shape = (basis_functions_per_direction(), quadrature_points_per_direction())
+    else:
+        shape = (quadrature_points_per_direction(), basis_functions_per_direction())
     globalarg(name, shape=shape, dtype=numpy.float64, dim_tags="f,f")
-    define_theta(name, shape, True, True)
-    return name
+    define_theta(name, shape, transpose, derivative)
+    return name
\ No newline at end of file
diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py
index 10021fd5..96f24219 100644
--- a/python/dune/perftool/sumfact/basis.py
+++ b/python/dune/perftool/sumfact/basis.py
@@ -15,7 +15,6 @@ from dune.perftool.generation import (backend,
                                       )
 from dune.perftool.sumfact.amatrix import (AMatrix,
                                            basis_functions_per_direction,
-                                           name_dtheta,
                                            name_theta,
                                            quadrature_points_per_direction,
                                            )
@@ -50,12 +49,10 @@ def sumfact_evaluate_coefficient_gradient(element, name, restriction, component)
     temporary_variable(name, shape=shape, shape_impl=shape_impl)
 
     # Calculate values with sumfactorization
-    theta = name_theta()
-    dtheta = name_dtheta()
     rows = quadrature_points_per_direction()
     cols = basis_functions_per_direction()
-    theta_matrix = AMatrix(theta, rows, cols)
-    dtheta_matrix = AMatrix(dtheta, rows, cols)
+    theta_matrix = AMatrix(rows, cols)
+    dtheta_matrix = AMatrix(rows, cols, derivative=True)
 
     # TODO:
     # - This only covers rank 1
@@ -117,10 +114,9 @@ def pymbolic_trialfunction(element, restriction, component):
     dim = formdata.geometric_dimension
 
     # Setup sumfactorization
-    theta = name_theta()
     rows = quadrature_points_per_direction()
     cols = basis_functions_per_direction()
-    a_matrix = AMatrix(theta, rows, cols)
+    a_matrix = AMatrix(rows, cols)
     a_matrices = (a_matrix,) * dim
 
     # Flip flop buffers for sumfactorization
@@ -203,7 +199,7 @@ def evaluate_reference_gradient(element, name, restriction):
 
     # Matrices for sumfactorization
     theta = name_theta()
-    dtheta = name_dtheta()
+    dtheta = name_theta(derivative=True)
 
     # Get geometric dimension
     formdata = get_global_context_value('formdata')
diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py
index eb44a4c7..2d981f07 100644
--- a/python/dune/perftool/sumfact/sumfact.py
+++ b/python/dune/perftool/sumfact/sumfact.py
@@ -32,10 +32,8 @@ from dune.perftool.pdelab.restriction import restricted_name
 from dune.perftool.pdelab.spaces import name_lfs
 from dune.perftool.sumfact.amatrix import (AMatrix,
                                            quadrature_points_per_direction,
-                                           name_dtheta_transposed,
                                            basis_functions_per_direction,
                                            name_theta,
-                                           name_theta_transposed,
                                            )
 from dune.perftool.loopy.symbolic import SumfactKernel
 from dune.perftool.error import PerftoolError
@@ -167,17 +165,15 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
     # TODO covers only 2D
     for i, buf in enumerate(buffers):
         # Get the a matrices needed for this accumulation term
-        theta_transposed = name_theta_transposed()
         rows = basis_functions_per_direction()
         cols = quadrature_points_per_direction()
-        theta_matrix = AMatrix(theta_transposed, rows, cols)
+        theta_matrix = AMatrix(rows, cols, transpose=True)
 
         # If this is a gradient we need different matrices
         if accterm.argument.index:
-            dtheta_transposed = name_dtheta_transposed()
             rows = basis_functions_per_direction()
             cols = quadrature_points_per_direction()
-            dtheta_matrix = AMatrix(dtheta_transposed, rows, cols)
+            dtheta_matrix = AMatrix(rows, cols, transpose=True, derivative=True)
 
             a_matrices = [theta_matrix] * dim
             a_matrices[i] = dtheta_matrix
@@ -307,7 +303,7 @@ def sum_factorization_kernel(a_matrices, buf, insn_dep=frozenset({}), additional
         k = sumfact_iname(a_matrix.cols, "red")
 
         # Construct the matrix-matrix-multiplication expression a_ik*in_kj
-        prod = Product((Subscript(Variable(a_matrix.a_matrix), (Variable(i), Variable(k))),
+        prod = Product((Subscript(Variable(name_theta(transpose=a_matrix.transpose, derivative=a_matrix.derivative)), (Variable(i), Variable(k))),
                         Subscript(Variable(inp), (Variable(k), Variable(j)))
                         ))
 
-- 
GitLab