diff --git a/python/dune/perftool/sumfact/basis.py b/python/dune/perftool/sumfact/basis.py index 3630d9f2750fa29874f7ee7e6bf722e581dfa0fc..4657cded8e573d3f366d38978341ef686ab9049d 100644 --- a/python/dune/perftool/sumfact/basis.py +++ b/python/dune/perftool/sumfact/basis.py @@ -17,7 +17,7 @@ from dune.perftool.generation import (backend, ) from dune.perftool.sumfact.tabulation import (basis_functions_per_direction, construct_basis_matrix_sequence, - name_theta, + BasisTabulationMatrix, quadrature_points_per_direction, PolynomialLookup, name_polynomials, @@ -163,7 +163,7 @@ def lfs_inames(element, restriction, number=1, context=''): @kernel_cached def evaluate_basis(element, name, restriction): temporary_variable(name, shape=()) - theta = name_theta() + theta = BasisTabulationMatrix().name quad_inames = quadrature_inames() inames = lfs_inames(element, restriction) facedir = get_facedir(restriction) @@ -223,7 +223,7 @@ def evaluate_reference_gradient(element, name, restriction): prod = [] for i in range(dim): if i != facedir: - prod.append(prim.Subscript(prim.Variable(name_theta(derivative=d == i)), + prod.append(prim.Subscript(prim.Variable(BasisTabulationMatrix(derivative=d == i).name), (prim.Variable(quadinamemapping[i]), prim.Variable(inames[i])) )) if facedir is not None: diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py index 910cb5df40a68e39c1afb5668a472608b5eef0ab..5630ca2d0a050fe4395a09476f56c4ad2f92c14a 100644 --- a/python/dune/perftool/sumfact/symbolic.py +++ b/python/dune/perftool/sumfact/symbolic.py @@ -340,8 +340,8 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) @property def matrix_sequence(self): - return tuple(BasisTabulationMatrixArray(rows=self.kernels[0].matrix_sequence[i].rows, - cols=self.kernels[0].matrix_sequence[i].cols, + return tuple(BasisTabulationMatrixArray(quadrature_size=self.kernels[0].matrix_sequence[i].quadrature_size, + basis_size=self.kernels[0].matrix_sequence[i].basis_size, transpose=self.kernels[0].matrix_sequence[i].transpose, face=self.kernels[0].matrix_sequence[i].face, derivative=tuple(k.matrix_sequence[i].derivative for k in self.kernels), diff --git a/python/dune/perftool/sumfact/tabulation.py b/python/dune/perftool/sumfact/tabulation.py index e64659851c51eea8be01dd3e69f8f258c799f0ad..8fb4442409d5f47e5e9f3693a8def789923148e4 100644 --- a/python/dune/perftool/sumfact/tabulation.py +++ b/python/dune/perftool/sumfact/tabulation.py @@ -40,18 +40,51 @@ class BasisTabulationMatrixBase(ImmutableRecord): class BasisTabulationMatrix(BasisTabulationMatrixBase): - def __init__(self, rows, cols, transpose=False, derivative=False, face=None): + def __init__(self, + quadrature_size=None, + basis_size=None, + transpose=False, + derivative=False, + face=None, + ): + if quadrature_size is None: + quadrature_size = quadrature_points_per_direction() + if basis_size is None: + basis_size = basis_functions_per_direction() BasisTabulationMatrixBase.__init__(self, - rows=rows, - cols=cols, + quadrature_size=quadrature_size, + basis_size=basis_size, transpose=transpose, derivative=derivative, face=face, ) + @property + def rows(self): + if self.transpose: + return self.basis_size + else: + return self.quadrature_size + + @property + def cols(self): + if self.transpose: + return self.quadrature_size + else: + return self.basis_size + @property def name(self): - return name_theta(self.transpose, self.derivative, face=self.face) + name = "{}{}Theta{}_qp{}_dof_{}".format("face{}_".format(self.face) if self.face is not None else "", + "d" if self.derivative else "", + "T" if self.transpose else "", + self.quadrature_size, + self.basis_size, + ) + + shape = (self.rows, self.cols) + define_theta(name, shape, self.transpose, self.derivative, face=self.face) + return name @property def vectorized(self): @@ -59,17 +92,31 @@ class BasisTabulationMatrix(BasisTabulationMatrixBase): class BasisTabulationMatrixArray(BasisTabulationMatrixBase): - def __init__(self, rows, cols, transpose, derivative, face, width): + def __init__(self, quadrature_size, basis_size, transpose, derivative, face, width): assert isinstance(derivative, tuple) BasisTabulationMatrixBase.__init__(self, - rows=rows, - cols=cols, + quadrature_size=quadrature_size, + basis_size=basis_size, transpose=transpose, derivative=derivative, face=face, width=width, ) + @property + def rows(self): + if self.transpose: + return self.basis_size + else: + return self.quadrature_size + + @property + def cols(self): + if self.transpose: + return self.quadrature_size + else: + return self.basis_size + @property def name(self): name = "ThetaLarge{}{}_{}_qp{}".format("face{}_".format(self.face) if self.face is not None else "", @@ -261,40 +308,23 @@ def define_theta(name, shape, transpose, derivative, face=None, additional_indic ) -def name_theta(transpose=False, derivative=False, face=None): - name = "{}{}Theta{}_qp_{}".format("face{}_".format(face) if face is not None else "", - "d" if derivative else "", - "T" if transpose else "", - quadrature_points_per_direction(), - ) - - shape = [quadrature_points_per_direction(), basis_functions_per_direction()] - if face is not None: - shape[0] = 1 - if transpose: - shape = shape[1], shape[0] - shape = tuple(shape) - - define_theta(name, shape, transpose, derivative, face=face) - return name - - def construct_basis_matrix_sequence(transpose=False, derivative=None, facedir=None, facemod=None): dim = world_dimension() result = [None] * dim for i in range(dim): - rows = quadrature_points_per_direction() - cols = basis_functions_per_direction() + quadrature_size = quadrature_points_per_direction() + basis_size = basis_functions_per_direction() onface = None if facedir == i: - rows = 1 + quadrature_size = 1 onface = facemod - if transpose: - rows, cols = cols, rows - - result[i] = BasisTabulationMatrix(rows, cols, transpose=transpose, derivative=derivative == i, face=onface) + result[i] = BasisTabulationMatrix(quadrature_size=quadrature_size, + basis_size=basis_size, + transpose=transpose, + derivative=derivative == i, + face=onface) return tuple(result)