Skip to content
Snippets Groups Projects
Commit b58d4fd0 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Simplify theta matrix generation

parent e57ee41b
No related branches found
No related tags found
No related merge requests found
...@@ -35,15 +35,16 @@ import numpy ...@@ -35,15 +35,16 @@ import numpy
class AMatrix(Record): class AMatrix(Record):
def __init__(self, a_matrix, rows, cols): def __init__(self, rows, cols, transpose=False, derivative=False):
Record.__init__(self, Record.__init__(self,
a_matrix=a_matrix,
rows=rows, rows=rows,
cols=cols, cols=cols,
transpose=False,
derivative=False,
) )
def __hash__(self): def __hash__(self):
return hash((self.a_matrix, self.rows, self.cols)) return hash((self.transpose, self.derivative, self.rows, self.cols))
def quadrature_points_per_direction(): def quadrature_points_per_direction():
...@@ -199,33 +200,12 @@ def define_theta(name, shape, transpose, derivative): ...@@ -199,33 +200,12 @@ def define_theta(name, shape, transpose, derivative):
) )
def name_theta(): def name_theta(transpose=False, derivative=False):
name = "Theta" name = "{}Theta{}".format("d" if derivative else "", "T" if transpose else "")
shape = (quadrature_points_per_direction(), basis_functions_per_direction()) if transpose:
globalarg(name, shape=shape, dtype=numpy.float64, dim_tags="f,f") shape = (basis_functions_per_direction(), quadrature_points_per_direction())
define_theta(name, shape, False, False) else:
return name shape = (quadrature_points_per_direction(), basis_functions_per_direction())
def name_theta_transposed():
name = "ThetaT"
shape = (basis_functions_per_direction(), quadrature_points_per_direction())
globalarg(name, shape=shape, dtype=numpy.float64, dim_tags="f,f")
define_theta(name, shape, True, False)
return name
def name_dtheta():
name = "dTheta"
shape = (quadrature_points_per_direction(), basis_functions_per_direction())
globalarg(name, shape=shape, dtype=numpy.float64, dim_tags="f,f")
define_theta(name, shape, False, True)
return name
def name_dtheta_transposed():
name = "dThetaT"
shape = (basis_functions_per_direction(), quadrature_points_per_direction())
globalarg(name, shape=shape, dtype=numpy.float64, dim_tags="f,f") globalarg(name, shape=shape, dtype=numpy.float64, dim_tags="f,f")
define_theta(name, shape, True, True) define_theta(name, shape, transpose, derivative)
return name return name
\ No newline at end of file
...@@ -15,7 +15,6 @@ from dune.perftool.generation import (backend, ...@@ -15,7 +15,6 @@ from dune.perftool.generation import (backend,
) )
from dune.perftool.sumfact.amatrix import (AMatrix, from dune.perftool.sumfact.amatrix import (AMatrix,
basis_functions_per_direction, basis_functions_per_direction,
name_dtheta,
name_theta, name_theta,
quadrature_points_per_direction, quadrature_points_per_direction,
) )
...@@ -50,12 +49,10 @@ def sumfact_evaluate_coefficient_gradient(element, name, restriction, component) ...@@ -50,12 +49,10 @@ def sumfact_evaluate_coefficient_gradient(element, name, restriction, component)
temporary_variable(name, shape=shape, shape_impl=shape_impl) temporary_variable(name, shape=shape, shape_impl=shape_impl)
# Calculate values with sumfactorization # Calculate values with sumfactorization
theta = name_theta()
dtheta = name_dtheta()
rows = quadrature_points_per_direction() rows = quadrature_points_per_direction()
cols = basis_functions_per_direction() cols = basis_functions_per_direction()
theta_matrix = AMatrix(theta, rows, cols) theta_matrix = AMatrix(rows, cols)
dtheta_matrix = AMatrix(dtheta, rows, cols) dtheta_matrix = AMatrix(rows, cols, derivative=True)
# TODO: # TODO:
# - This only covers rank 1 # - This only covers rank 1
...@@ -117,10 +114,9 @@ def pymbolic_trialfunction(element, restriction, component): ...@@ -117,10 +114,9 @@ def pymbolic_trialfunction(element, restriction, component):
dim = formdata.geometric_dimension dim = formdata.geometric_dimension
# Setup sumfactorization # Setup sumfactorization
theta = name_theta()
rows = quadrature_points_per_direction() rows = quadrature_points_per_direction()
cols = basis_functions_per_direction() cols = basis_functions_per_direction()
a_matrix = AMatrix(theta, rows, cols) a_matrix = AMatrix(rows, cols)
a_matrices = (a_matrix,) * dim a_matrices = (a_matrix,) * dim
# Flip flop buffers for sumfactorization # Flip flop buffers for sumfactorization
...@@ -203,7 +199,7 @@ def evaluate_reference_gradient(element, name, restriction): ...@@ -203,7 +199,7 @@ def evaluate_reference_gradient(element, name, restriction):
# Matrices for sumfactorization # Matrices for sumfactorization
theta = name_theta() theta = name_theta()
dtheta = name_dtheta() dtheta = name_theta(derivative=True)
# Get geometric dimension # Get geometric dimension
formdata = get_global_context_value('formdata') formdata = get_global_context_value('formdata')
......
...@@ -32,10 +32,8 @@ from dune.perftool.pdelab.restriction import restricted_name ...@@ -32,10 +32,8 @@ from dune.perftool.pdelab.restriction import restricted_name
from dune.perftool.pdelab.spaces import name_lfs from dune.perftool.pdelab.spaces import name_lfs
from dune.perftool.sumfact.amatrix import (AMatrix, from dune.perftool.sumfact.amatrix import (AMatrix,
quadrature_points_per_direction, quadrature_points_per_direction,
name_dtheta_transposed,
basis_functions_per_direction, basis_functions_per_direction,
name_theta, name_theta,
name_theta_transposed,
) )
from dune.perftool.loopy.symbolic import SumfactKernel from dune.perftool.loopy.symbolic import SumfactKernel
from dune.perftool.error import PerftoolError from dune.perftool.error import PerftoolError
...@@ -167,17 +165,15 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id): ...@@ -167,17 +165,15 @@ def generate_accumulation_instruction(visitor, accterm, measure, subdomain_id):
# TODO covers only 2D # TODO covers only 2D
for i, buf in enumerate(buffers): for i, buf in enumerate(buffers):
# Get the a matrices needed for this accumulation term # Get the a matrices needed for this accumulation term
theta_transposed = name_theta_transposed()
rows = basis_functions_per_direction() rows = basis_functions_per_direction()
cols = quadrature_points_per_direction() cols = quadrature_points_per_direction()
theta_matrix = AMatrix(theta_transposed, rows, cols) theta_matrix = AMatrix(rows, cols, transpose=True)
# If this is a gradient we need different matrices # If this is a gradient we need different matrices
if accterm.argument.index: if accterm.argument.index:
dtheta_transposed = name_dtheta_transposed()
rows = basis_functions_per_direction() rows = basis_functions_per_direction()
cols = quadrature_points_per_direction() cols = quadrature_points_per_direction()
dtheta_matrix = AMatrix(dtheta_transposed, rows, cols) dtheta_matrix = AMatrix(rows, cols, transpose=True, derivative=True)
a_matrices = [theta_matrix] * dim a_matrices = [theta_matrix] * dim
a_matrices[i] = dtheta_matrix a_matrices[i] = dtheta_matrix
...@@ -307,7 +303,7 @@ def sum_factorization_kernel(a_matrices, buf, insn_dep=frozenset({}), additional ...@@ -307,7 +303,7 @@ def sum_factorization_kernel(a_matrices, buf, insn_dep=frozenset({}), additional
k = sumfact_iname(a_matrix.cols, "red") k = sumfact_iname(a_matrix.cols, "red")
# Construct the matrix-matrix-multiplication expression a_ik*in_kj # Construct the matrix-matrix-multiplication expression a_ik*in_kj
prod = Product((Subscript(Variable(a_matrix.a_matrix), (Variable(i), Variable(k))), prod = Product((Subscript(Variable(name_theta(transpose=a_matrix.transpose, derivative=a_matrix.derivative)), (Variable(i), Variable(k))),
Subscript(Variable(inp), (Variable(k), Variable(j))) Subscript(Variable(inp), (Variable(k), Variable(j)))
)) ))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment