diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py index 174e9469cc0c3d64283fcde1ef36a79a095fbb98..cdc3aba01bbcca3739cc1dd8a85a6ce597107d99 100644 --- a/python/dune/perftool/sumfact/symbolic.py +++ b/python/dune/perftool/sumfact/symbolic.py @@ -174,7 +174,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): @property def transposed(self): - return next(iter(self.matrix_sequence)).transpose + return self.matrix_sequence[0].transpose def vec_index(self, sf): """ Map an unvectorized sumfact kernel object to its position @@ -188,7 +188,8 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): @property def flat_input_shape(self): """ The 'flat' input tensor shape """ - return (product(mat.cols for mat in self.matrix_sequence),) + assert self.stage == 1 + return (product(mat.basis_size for mat in self.matrix_sequence),) @property def quadrature_shape(self): @@ -196,10 +197,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): Takes into account the lower dimensionality of faces and vectorization. """ - if self.transposed: - return tuple(mat.cols for mat in self.matrix_sequence if mat.face is None) - else: - return tuple(mat.rows for mat in self.matrix_sequence if mat.face is None) + return tuple(mat.quadrature_size for mat in self.matrix_sequence if mat.face is None) @property def quadrature_dimtags(self): @@ -216,8 +214,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): Takes into account vectorization. """ - shape = tuple(mat.rows for mat in self.matrix_sequence) - return shape + return tuple(mat.basis_size for mat in self.matrix_sequence) @property def dof_dimtags(self): @@ -379,6 +376,10 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) assert len(set(k.accumvar for k in self.kernels)) == 1 return self.kernels[0].accumvar + @property + def transposed(self): + return self.kernels[0].transposed + # # Define some properties only needed for this one # @@ -407,10 +408,6 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) def vectorized(self): return True - @property - def transposed(self): - return self.kernels[0].transposed - def vec_index(self, sf): return self.kernels.index(sf) @@ -423,10 +420,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) @property def quadrature_shape(self): - if self.transposed: - return tuple(mat.cols for mat in self.matrix_sequence if mat.face is None) + (self.horizontal_width,) - else: - return tuple(mat.rows for mat in self.matrix_sequence if mat.face is None) + (self.horizontal_width,) + return tuple(mat.quadrature_size for mat in self.matrix_sequence if mat.face is None) + (self.horizontal_width,) @property def quadrature_dimtags(self): @@ -436,7 +430,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) @property def dof_shape(self): - return tuple(mat.rows for mat in self.matrix_sequence) + (self.horizontal_width,) + return tuple(mat.basis_size for mat in self.matrix_sequence) + (self.horizontal_width,) @property def dof_dimtags(self):