diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py index 5630ca2d0a050fe4395a09476f56c4ad2f92c14a..ae485fc1c4a112270c79d32091a7397b4bd2cfb8 100644 --- a/python/dune/perftool/sumfact/symbolic.py +++ b/python/dune/perftool/sumfact/symbolic.py @@ -340,11 +340,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) @property def matrix_sequence(self): - return tuple(BasisTabulationMatrixArray(quadrature_size=self.kernels[0].matrix_sequence[i].quadrature_size, - basis_size=self.kernels[0].matrix_sequence[i].basis_size, - transpose=self.kernels[0].matrix_sequence[i].transpose, - face=self.kernels[0].matrix_sequence[i].face, - derivative=tuple(k.matrix_sequence[i].derivative for k in self.kernels), + return tuple(BasisTabulationMatrixArray(tuple(k.matrix_sequence[i] for k in self.kernels), width=self.vector_width, ) for i in range(self.length)) diff --git a/python/dune/perftool/sumfact/tabulation.py b/python/dune/perftool/sumfact/tabulation.py index 8fb4442409d5f47e5e9f3693a8def789923148e4..bac04f58a84cdf55de73113a86d4a1ced6e99e75 100644 --- a/python/dune/perftool/sumfact/tabulation.py +++ b/python/dune/perftool/sumfact/tabulation.py @@ -35,11 +35,11 @@ import loopy as lp import numpy -class BasisTabulationMatrixBase(ImmutableRecord): +class BasisTabulationMatrixBase(object): pass -class BasisTabulationMatrix(BasisTabulationMatrixBase): +class BasisTabulationMatrix(BasisTabulationMatrixBase, ImmutableRecord): def __init__(self, quadrature_size=None, basis_size=None, @@ -51,13 +51,13 @@ class BasisTabulationMatrix(BasisTabulationMatrixBase): quadrature_size = quadrature_points_per_direction() if basis_size is None: basis_size = basis_functions_per_direction() - BasisTabulationMatrixBase.__init__(self, - quadrature_size=quadrature_size, - basis_size=basis_size, - transpose=transpose, - derivative=derivative, - face=face, - ) + ImmutableRecord.__init__(self, + quadrature_size=quadrature_size, + basis_size=basis_size, + transpose=transpose, + derivative=derivative, + face=face, + ) @property def rows(self): @@ -81,9 +81,7 @@ class BasisTabulationMatrix(BasisTabulationMatrixBase): self.quadrature_size, self.basis_size, ) - - shape = (self.rows, self.cols) - define_theta(name, shape, self.transpose, self.derivative, face=self.face) + define_theta(name, self) return name @property @@ -92,16 +90,39 @@ class BasisTabulationMatrix(BasisTabulationMatrixBase): class BasisTabulationMatrixArray(BasisTabulationMatrixBase): - def __init__(self, quadrature_size, basis_size, transpose, derivative, face, width): - assert isinstance(derivative, tuple) - BasisTabulationMatrixBase.__init__(self, - quadrature_size=quadrature_size, - basis_size=basis_size, - transpose=transpose, - derivative=derivative, - face=face, - width=width, - ) + def __init__(self, tabs, width=None): + assert isinstance(tabs, tuple) + + # Assert that all the basis tabulations match in size! + assert len(set(t.quadrature_size for t in tabs)) == 1 + assert len(set(t.basis_size for t in tabs)) == 1 + assert len(set(t.transpose for t in tabs)) == 1 + assert len(set(t.face for t in tabs)) == 1 + self.tabs = tabs + + if width is None: + width = len(tabs) + self.width = width + + @property + def quadrature_size(self): + return self.tabs[0].quadrature_size + + @property + def basis_size(self): + return self.tabs[0].basis_size + + @property + def transpose(self): + return self.tabs[0].transpose + + @property + def face(self): + return self.tabs[0].face + + @property + def derivative(self): + return tuple(t.derivative for t in self.tabs) @property def rows(self): @@ -124,8 +145,8 @@ class BasisTabulationMatrixArray(BasisTabulationMatrixBase): "x".join(tuple("d" if d else "" for d in self.derivative)), quadrature_points_per_direction(), ) - for i, d in enumerate(self.derivative): - define_theta(name, (self.rows, self.cols), self.transpose, d, face=self.face, additional_indices=(i,), width=self.width) + for i, tab in enumerate(self.tabs): + define_theta(name, tab, additional_indices=(i,), width=self.width) # Apply padding to those fields not used. This is necessary because you may get memory # initialized with NaN and those NaNs will screw the horizontal_add. @@ -271,12 +292,14 @@ def polynomial_lookup_mangler(target, func, dtypes): return CallMangleInfo(func.name, (NumpyType(numpy.float64),), (NumpyType(numpy.int32), NumpyType(numpy.float64))) -def define_theta(name, shape, transpose, derivative, face=None, additional_indices=(), width=None): +def define_theta(name, tabmat, additional_indices=(), width=None): + assert isinstance(tabmat, BasisTabulationMatrix) qp = name_oned_quadrature_points() qw = name_oned_quadrature_weights() sort_quadrature_points_weights(qp, qw) polynomials = name_polynomials() + shape = (tabmat.rows, tabmat.cols) dim_tags = "f,f" if additional_indices: dim_tags = dim_tags + ",c" @@ -291,19 +314,19 @@ def define_theta(name, shape, transpose, derivative, face=None, additional_indic potentially_vectorized=True, ) - i = theta_iname("i", shape[0]) - j = theta_iname("j", shape[1]) + i = theta_iname("i", tabmat.rows) + j = theta_iname("j", tabmat.cols) inames = i, j - if transpose: + if tabmat.transpose: inames = j, i args = [prim.Variable(inames[1]), prim.Subscript(prim.Variable(qp), (prim.Variable(inames[0]),))] - if face is not None: - args[1] = face + if tabmat.face is not None: + args[1] = tabmat.face instruction(assignee=prim.Subscript(prim.Variable(name), (prim.Variable(i), prim.Variable(j)) + additional_indices), - expression=prim.Call(PolynomialLookup(polynomials, derivative), tuple(args)), + expression=prim.Call(PolynomialLookup(polynomials, tabmat.derivative), tuple(args)), kernel="operator", )