From b58d4fd03a81aa47cde0b7f73236871c12c35a47 Mon Sep 17 00:00:00 2001
From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de>
Date: Mon, 28 Nov 2016 17:49:47 +0100
Subject: [PATCH] Simplify theta matrix generation

---
Review notes (free text in this position is not part of the diff):
 * AMatrix.__init__ forwards its transpose/derivative arguments to
   Record.__init__ instead of passing the literal False, so the flags that
   __hash__ and sum_factorization_kernel read from the instance are the
   ones the caller supplied.
 * Line breaks and indentation were rebuilt from a whitespace-collapsed
   copy of this patch; the context-line indentation is a best guess --
   TODO confirm against the target tree before applying.
 * NOTE(review): the amatrix.py hunk ends with "No newline at end of file",
   i.e. the resulting file has no trailing newline; consider restoring it.

 python/dune/perftool/sumfact/amatrix.py | 44 +++++++------------------
 python/dune/perftool/sumfact/basis.py   | 12 +++----
 python/dune/perftool/sumfact/sumfact.py | 10 ++----
 3 files changed, 19 insertions(+), 47 deletions(-)

diff --git a/python/dune/perftool/sumfact/amatrix.py b/python/dune/perftool/sumfact/amatrix.py
index 1c786393..41fe8ed7 100644
--- a/python/dune/perftool/sumfact/amatrix.py
+++ b/python/dune/perftool/sumfact/amatrix.py
@@ -35,15 +35,16 @@ import numpy
 
 
 class AMatrix(Record):
-    def __init__(self, a_matrix, rows, cols):
+    def __init__(self, rows, cols, transpose=False, derivative=False):
         Record.__init__(self,
-                        a_matrix=a_matrix,
                         rows=rows,
                         cols=cols,
+                        transpose=transpose,
+                        derivative=derivative,
                         )
 
     def __hash__(self):
-        return hash((self.a_matrix, self.rows, self.cols))
+        return hash((self.transpose, self.derivative, self.rows, self.cols))
 
 
 def quadrature_points_per_direction():
@@ -199,33 +200,12 @@ def define_theta(name, shape, transpose, derivative):
                 )
 
 
-def name_theta():
-    name = "Theta"
-    shape = (quadrature_points_per_direction(), basis_functions_per_direction())
-    globalarg(name, shape=shape, dtype=numpy.float64, dim_tags="f,f")
-    define_theta(name, shape, False, False)
-    return name
-
-
-def name_theta_transposed():
-    name = "ThetaT"
-    shape = (basis_functions_per_direction(), quadrature_points_per_direction())
-    globalarg(name, shape=shape, dtype=numpy.float64, dim_tags="f,f")
-    define_theta(name, shape, True, False)
-    return name
-
-
-def name_dtheta():
-    name = "dTheta"
-    shape = (quadrature_points_per_direction(), basis_functions_per_direction())
-    globalarg(name, shape=shape, dtype=numpy.float64, dim_tags="f,f")
-    define_theta(name, shape, False, True)
-    return name
-
-
-def name_dtheta_transposed():
-    name = "dThetaT"
-    shape = (basis_functions_per_direction(), quadrature_points_per_direction())
+def name_theta(transpose=False, derivative=False):
+    name = "{}Theta{}".format("d" if derivative else "", "T" if transpose else "")
+    if transpose:
+        shape = (basis_functions_per_direction(), quadrature_points_per_direction())
+    else:
+        shape = (quadrature_points_per_direction(), basis_functions_per_direction())
     globalarg(name, shape=shape, dtype=numpy.float64, dim_tags="f,f")
-    define_theta(name, shape, True, True)
-    return name
+    define_theta(name, shape, transpose, derivative)
+    return name
\ No newline at end of file
diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py
index 10021fd5..96f24219 100644
--- a/python/dune/perftool/sumfact/basis.py
+++ b/python/dune/perftool/sumfact/basis.py
@@ -15,7 +15,6 @@ from dune.perftool.generation import (backend,
                                       )
 from dune.perftool.sumfact.amatrix import (AMatrix,
                                            basis_functions_per_direction,
-                                           name_dtheta,
                                            name_theta,
                                            quadrature_points_per_direction,
                                            )
@@ -50,12 +49,10 @@ def sumfact_evaluate_coefficient_gradient(element, name, restriction, component)
     temporary_variable(name, shape=shape, shape_impl=shape_impl)
 
     # Calculate values with sumfactorization
-    theta = name_theta()
-    dtheta = name_dtheta()
     rows = quadrature_points_per_direction()
     cols = basis_functions_per_direction()
-    theta_matrix = AMatrix(theta, rows, cols)
-    dtheta_matrix = AMatrix(dtheta, rows, cols)
+    theta_matrix = AMatrix(rows, cols)
+    dtheta_matrix = AMatrix(rows, cols, derivative=True)
 
     # TODO:
     # - This only covers rank 1
@@ -117,10 +114,9 @@ def pymbolic_trialfunction(element, restriction, component):
     dim = formdata.geometric_dimension
 
     # Setup sumfactorization
-    theta = name_theta()
     rows = quadrature_points_per_direction()
     cols = basis_functions_per_direction()
-    a_matrix = AMatrix(theta, rows, cols)
+    a_matrix = AMatrix(rows, cols)
     a_matrices = (a_matrix,) * dim
 
     # Flip flop buffers for sumfactorization
@@ -203,7 +199,7 @@ def evaluate_reference_gradient(element, name, restriction):
 
     # Matrices for sumfactorization
     theta = name_theta()
-    dtheta = name_dtheta()
+    dtheta = name_theta(derivative=True)
 
     # Get geometric dimension
     formdata = get_global_context_value('formdata')
diff --git a/python/dune/perftool/sumfact/sumfact.py b/python/dune/perftool/sumfact/sumfact.py
index eb44a4c7..2d981f07 100644
--- a/python/dune/perftool/sumfact/sumfact.py
+++ b/python/dune/perftool/sumfact/sumfact.py
@@ -32,10 +32,8 @@ from dune.perftool.pdelab.restriction import restricted_name
 from dune.perftool.pdelab.spaces import name_lfs
 from dune.perftool.sumfact.amatrix import (AMatrix,
                                            quadrature_points_per_direction,
-                                           name_dtheta_transposed,
                                            basis_functions_per_direction,
                                            name_theta,
-                                           name_theta_transposed,
                                            )
 from dune.perftool.loopy.symbolic import SumfactKernel
 from dune.perftool.error import PerftoolError
@@ -167,17 +165,15 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
     # TODO covers only 2D
     for i, buf in enumerate(buffers):
         # Get the a matrices needed for this accumulation term
-        theta_transposed = name_theta_transposed()
         rows = basis_functions_per_direction()
         cols = quadrature_points_per_direction()
-        theta_matrix = AMatrix(theta_transposed, rows, cols)
+        theta_matrix = AMatrix(rows, cols, transpose=True)
 
         # If this is a gradient we need different matrices
         if accterm.argument.index:
-            dtheta_transposed = name_dtheta_transposed()
             rows = basis_functions_per_direction()
             cols = quadrature_points_per_direction()
-            dtheta_matrix = AMatrix(dtheta_transposed, rows, cols)
+            dtheta_matrix = AMatrix(rows, cols, transpose=True, derivative=True)
             a_matrices = [theta_matrix] * dim
             a_matrices[i] = dtheta_matrix
 
@@ -307,7 +303,7 @@ def sum_factorization_kernel(a_matrices, buf, insn_dep=frozenset({}), additional
         k = sumfact_iname(a_matrix.cols, "red")
 
         # Construct the matrix-matrix-multiplication expression a_ik*in_kj
-        prod = Product((Subscript(Variable(a_matrix.a_matrix), (Variable(i), Variable(k))),
+        prod = Product((Subscript(Variable(name_theta(transpose=a_matrix.transpose, derivative=a_matrix.derivative)), (Variable(i), Variable(k))),
                         Subscript(Variable(inp), (Variable(k), Variable(j)))
                         ))
 
-- 
GitLab