Skip to content
Snippets Groups Projects
Commit 4a7f6ce0 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Do not store rows/cols AND transpose on BasisTabulationMatrix

Instead, store quadrature_size, basis_size and transpose
and have properties giving rows and cols.
parent 8e920583
No related branches found
No related tags found
No related merge requests found
...@@ -17,7 +17,7 @@ from dune.perftool.generation import (backend, ...@@ -17,7 +17,7 @@ from dune.perftool.generation import (backend,
) )
from dune.perftool.sumfact.tabulation import (basis_functions_per_direction, from dune.perftool.sumfact.tabulation import (basis_functions_per_direction,
construct_basis_matrix_sequence, construct_basis_matrix_sequence,
name_theta, BasisTabulationMatrix,
quadrature_points_per_direction, quadrature_points_per_direction,
PolynomialLookup, PolynomialLookup,
name_polynomials, name_polynomials,
...@@ -163,7 +163,7 @@ def lfs_inames(element, restriction, number=1, context=''): ...@@ -163,7 +163,7 @@ def lfs_inames(element, restriction, number=1, context=''):
@kernel_cached @kernel_cached
def evaluate_basis(element, name, restriction): def evaluate_basis(element, name, restriction):
temporary_variable(name, shape=()) temporary_variable(name, shape=())
theta = name_theta() theta = BasisTabulationMatrix().name
quad_inames = quadrature_inames() quad_inames = quadrature_inames()
inames = lfs_inames(element, restriction) inames = lfs_inames(element, restriction)
facedir = get_facedir(restriction) facedir = get_facedir(restriction)
...@@ -223,7 +223,7 @@ def evaluate_reference_gradient(element, name, restriction): ...@@ -223,7 +223,7 @@ def evaluate_reference_gradient(element, name, restriction):
prod = [] prod = []
for i in range(dim): for i in range(dim):
if i != facedir: if i != facedir:
prod.append(prim.Subscript(prim.Variable(name_theta(derivative=d == i)), prod.append(prim.Subscript(prim.Variable(BasisTabulationMatrix(derivative=d == i).name),
(prim.Variable(quadinamemapping[i]), prim.Variable(inames[i])) (prim.Variable(quadinamemapping[i]), prim.Variable(inames[i]))
)) ))
if facedir is not None: if facedir is not None:
......
...@@ -340,8 +340,8 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) ...@@ -340,8 +340,8 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
@property @property
def matrix_sequence(self): def matrix_sequence(self):
return tuple(BasisTabulationMatrixArray(rows=self.kernels[0].matrix_sequence[i].rows, return tuple(BasisTabulationMatrixArray(quadrature_size=self.kernels[0].matrix_sequence[i].quadrature_size,
cols=self.kernels[0].matrix_sequence[i].cols, basis_size=self.kernels[0].matrix_sequence[i].basis_size,
transpose=self.kernels[0].matrix_sequence[i].transpose, transpose=self.kernels[0].matrix_sequence[i].transpose,
face=self.kernels[0].matrix_sequence[i].face, face=self.kernels[0].matrix_sequence[i].face,
derivative=tuple(k.matrix_sequence[i].derivative for k in self.kernels), derivative=tuple(k.matrix_sequence[i].derivative for k in self.kernels),
......
...@@ -40,18 +40,51 @@ class BasisTabulationMatrixBase(ImmutableRecord): ...@@ -40,18 +40,51 @@ class BasisTabulationMatrixBase(ImmutableRecord):
class BasisTabulationMatrix(BasisTabulationMatrixBase): class BasisTabulationMatrix(BasisTabulationMatrixBase):
def __init__(self, rows, cols, transpose=False, derivative=False, face=None): def __init__(self,
quadrature_size=None,
basis_size=None,
transpose=False,
derivative=False,
face=None,
):
if quadrature_size is None:
quadrature_size = quadrature_points_per_direction()
if basis_size is None:
basis_size = basis_functions_per_direction()
BasisTabulationMatrixBase.__init__(self, BasisTabulationMatrixBase.__init__(self,
rows=rows, quadrature_size=quadrature_size,
cols=cols, basis_size=basis_size,
transpose=transpose, transpose=transpose,
derivative=derivative, derivative=derivative,
face=face, face=face,
) )
@property
def rows(self):
if self.transpose:
return self.basis_size
else:
return self.quadrature_size
@property
def cols(self):
if self.transpose:
return self.quadrature_size
else:
return self.basis_size
@property @property
def name(self): def name(self):
return name_theta(self.transpose, self.derivative, face=self.face) name = "{}{}Theta{}_qp{}_dof_{}".format("face{}_".format(self.face) if self.face is not None else "",
"d" if self.derivative else "",
"T" if self.transpose else "",
self.quadrature_size,
self.basis_size,
)
shape = (self.rows, self.cols)
define_theta(name, shape, self.transpose, self.derivative, face=self.face)
return name
@property @property
def vectorized(self): def vectorized(self):
...@@ -59,17 +92,31 @@ class BasisTabulationMatrix(BasisTabulationMatrixBase): ...@@ -59,17 +92,31 @@ class BasisTabulationMatrix(BasisTabulationMatrixBase):
class BasisTabulationMatrixArray(BasisTabulationMatrixBase): class BasisTabulationMatrixArray(BasisTabulationMatrixBase):
def __init__(self, rows, cols, transpose, derivative, face, width): def __init__(self, quadrature_size, basis_size, transpose, derivative, face, width):
assert isinstance(derivative, tuple) assert isinstance(derivative, tuple)
BasisTabulationMatrixBase.__init__(self, BasisTabulationMatrixBase.__init__(self,
rows=rows, quadrature_size=quadrature_size,
cols=cols, basis_size=basis_size,
transpose=transpose, transpose=transpose,
derivative=derivative, derivative=derivative,
face=face, face=face,
width=width, width=width,
) )
@property
def rows(self):
if self.transpose:
return self.basis_size
else:
return self.quadrature_size
@property
def cols(self):
if self.transpose:
return self.quadrature_size
else:
return self.basis_size
@property @property
def name(self): def name(self):
name = "ThetaLarge{}{}_{}_qp{}".format("face{}_".format(self.face) if self.face is not None else "", name = "ThetaLarge{}{}_{}_qp{}".format("face{}_".format(self.face) if self.face is not None else "",
...@@ -261,40 +308,23 @@ def define_theta(name, shape, transpose, derivative, face=None, additional_indic ...@@ -261,40 +308,23 @@ def define_theta(name, shape, transpose, derivative, face=None, additional_indic
) )
def name_theta(transpose=False, derivative=False, face=None):
name = "{}{}Theta{}_qp_{}".format("face{}_".format(face) if face is not None else "",
"d" if derivative else "",
"T" if transpose else "",
quadrature_points_per_direction(),
)
shape = [quadrature_points_per_direction(), basis_functions_per_direction()]
if face is not None:
shape[0] = 1
if transpose:
shape = shape[1], shape[0]
shape = tuple(shape)
define_theta(name, shape, transpose, derivative, face=face)
return name
def construct_basis_matrix_sequence(transpose=False, derivative=None, facedir=None, facemod=None): def construct_basis_matrix_sequence(transpose=False, derivative=None, facedir=None, facemod=None):
dim = world_dimension() dim = world_dimension()
result = [None] * dim result = [None] * dim
for i in range(dim): for i in range(dim):
rows = quadrature_points_per_direction() quadrature_size = quadrature_points_per_direction()
cols = basis_functions_per_direction() basis_size = basis_functions_per_direction()
onface = None onface = None
if facedir == i: if facedir == i:
rows = 1 quadrature_size = 1
onface = facemod onface = facemod
if transpose: result[i] = BasisTabulationMatrix(quadrature_size=quadrature_size,
rows, cols = cols, rows basis_size=basis_size,
transpose=transpose,
result[i] = BasisTabulationMatrix(rows, cols, transpose=transpose, derivative=derivative == i, face=onface) derivative=derivative == i,
face=onface)
return tuple(result) return tuple(result)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment