Skip to content
Snippets Groups Projects
Commit 09d676df authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Fixups

parent 465f90be
No related branches found
No related tags found
No related merge requests found
......@@ -191,7 +191,7 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
def flat_input_shape(self):
""" The 'flat' input tensor shape """
assert self.stage == 1
return (product(mat.basis_size for mat in self.matrix_sequence),)
return (product(mat.basis_size for mat in self.matrix_sequence), 1)
@property
def quadrature_shape(self):
......@@ -271,6 +271,9 @@ class SumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable):
def horizontal_width(self):
return 1
def horizontal_index(self, _):
return 0
@property
def vertical_width(self):
return 1
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment