From 6eaa6ac8085341a699f22e37026be5a96f125c59 Mon Sep 17 00:00:00 2001 From: Dominic Kempf <dominic.kempf@iwr.uni-heidelberg.de> Date: Thu, 6 Apr 2017 16:34:38 +0200 Subject: [PATCH] Use basis_size/quadrature_size --- python/dune/perftool/sumfact/symbolic.py | 28 ++++++++++-------------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/python/dune/perftool/sumfact/symbolic.py b/python/dune/perftool/sumfact/symbolic.py index 174e9469..cdc3aba0 100644 --- a/python/dune/perftool/sumfact/symbolic.py +++ b/python/dune/perftool/sumfact/symbolic.py @@ -174,7 +174,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): @property def transposed(self): - return next(iter(self.matrix_sequence)).transpose + return self.matrix_sequence[0].transpose def vec_index(self, sf): """ Map an unvectorized sumfact kernel object to its position @@ -188,7 +188,8 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): @property def flat_input_shape(self): """ The 'flat' input tensor shape """ - return (product(mat.cols for mat in self.matrix_sequence),) + assert self.stage == 1 + return (product(mat.basis_size for mat in self.matrix_sequence),) @property def quadrature_shape(self): @@ -196,10 +197,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): Takes into account the lower dimensionality of faces and vectorization. """ - if self.transposed: - return tuple(mat.cols for mat in self.matrix_sequence if mat.face is None) - else: - return tuple(mat.rows for mat in self.matrix_sequence if mat.face is None) + return tuple(mat.quadrature_size for mat in self.matrix_sequence if mat.face is None) @property def quadrature_dimtags(self): @@ -216,8 +214,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): Takes into account vectorization. """ - shape = tuple(mat.rows for mat in self.matrix_sequence) - return shape + return tuple(mat.basis_size for mat in self.matrix_sequence) @property def dof_dimtags(self): @@ -379,6 +376,10 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) assert len(set(k.accumvar for k in self.kernels)) == 1 return self.kernels[0].accumvar + @property + def transposed(self): + return self.kernels[0].transposed + # # Define some properties only needed for this one # @@ -407,10 +408,6 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) def vectorized(self): return True - @property - def transposed(self): - return self.kernels[0].transposed - def vec_index(self, sf): return self.kernels.index(sf) @@ -423,10 +420,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) @property def quadrature_shape(self): - if self.transposed: - return tuple(mat.cols for mat in self.matrix_sequence if mat.face is None) + (self.horizontal_width,) - else: - return tuple(mat.rows for mat in self.matrix_sequence if mat.face is None) + (self.horizontal_width,) + return tuple(mat.quadrature_size for mat in self.matrix_sequence if mat.face is None) + (self.horizontal_width,) @property def quadrature_dimtags(self): @@ -436,7 +430,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) @property def dof_shape(self): - return tuple(mat.rows for mat in self.matrix_sequence) + (self.horizontal_width,) + return tuple(mat.basis_size for mat in self.matrix_sequence) + (self.horizontal_width,) @property def dof_dimtags(self): -- GitLab