Skip to content
Snippets Groups Projects
Commit 6eaa6ac8 authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Use basis_size/quadrature_size

parent f6884593
No related branches found
No related tags found
No related merge requests found
...@@ -174,7 +174,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): ...@@ -174,7 +174,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
@property @property
def transposed(self): def transposed(self):
return next(iter(self.matrix_sequence)).transpose return self.matrix_sequence[0].transpose
def vec_index(self, sf): def vec_index(self, sf):
""" Map an unvectorized sumfact kernel object to its position """ Map an unvectorized sumfact kernel object to its position
...@@ -188,7 +188,8 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): ...@@ -188,7 +188,8 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
@property @property
def flat_input_shape(self): def flat_input_shape(self):
""" The 'flat' input tensor shape """ """ The 'flat' input tensor shape """
return (product(mat.cols for mat in self.matrix_sequence),) assert self.stage == 1
return (product(mat.basis_size for mat in self.matrix_sequence),)
@property @property
def quadrature_shape(self): def quadrature_shape(self):
...@@ -196,10 +197,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): ...@@ -196,10 +197,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
Takes into account the lower dimensionality of faces and vectorization. Takes into account the lower dimensionality of faces and vectorization.
""" """
if self.transposed: return tuple(mat.quadrature_size for mat in self.matrix_sequence if mat.face is None)
return tuple(mat.cols for mat in self.matrix_sequence if mat.face is None)
else:
return tuple(mat.rows for mat in self.matrix_sequence if mat.face is None)
@property @property
def quadrature_dimtags(self): def quadrature_dimtags(self):
...@@ -216,8 +214,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable): ...@@ -216,8 +214,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
Takes into account vectorization. Takes into account vectorization.
""" """
shape = tuple(mat.rows for mat in self.matrix_sequence) return tuple(mat.basis_size for mat in self.matrix_sequence)
return shape
@property @property
def dof_dimtags(self): def dof_dimtags(self):
...@@ -379,6 +376,10 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) ...@@ -379,6 +376,10 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
assert len(set(k.accumvar for k in self.kernels)) == 1 assert len(set(k.accumvar for k in self.kernels)) == 1
return self.kernels[0].accumvar return self.kernels[0].accumvar
@property
def transposed(self):
return self.kernels[0].transposed
# #
# Define some properties only needed for this one # Define some properties only needed for this one
# #
...@@ -407,10 +408,6 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) ...@@ -407,10 +408,6 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
def vectorized(self): def vectorized(self):
return True return True
@property
def transposed(self):
return self.kernels[0].transposed
def vec_index(self, sf): def vec_index(self, sf):
return self.kernels.index(sf) return self.kernels.index(sf)
...@@ -423,10 +420,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) ...@@ -423,10 +420,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
@property @property
def quadrature_shape(self): def quadrature_shape(self):
if self.transposed: return tuple(mat.quadrature_size for mat in self.matrix_sequence if mat.face is None) + (self.horizontal_width,)
return tuple(mat.cols for mat in self.matrix_sequence if mat.face is None) + (self.horizontal_width,)
else:
return tuple(mat.rows for mat in self.matrix_sequence if mat.face is None) + (self.horizontal_width,)
@property @property
def quadrature_dimtags(self): def quadrature_dimtags(self):
...@@ -436,7 +430,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable) ...@@ -436,7 +430,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
@property @property
def dof_shape(self): def dof_shape(self):
return tuple(mat.rows for mat in self.matrix_sequence) + (self.horizontal_width,) return tuple(mat.basis_size for mat in self.matrix_sequence) + (self.horizontal_width,)
@property @property
def dof_dimtags(self): def dof_dimtags(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment