Skip to content
Snippets Groups Projects
Commit a7950143 authored by René Heß's avatar René Heß
Browse files

[skip ci] Do not use matrix_sequence in VectorizedSumfactKernel

parent 8735bd57
No related branches found
No related tags found
No related merge requests found
......@@ -667,7 +667,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
# Above stringifier just calls back into this
return "VSF{}:[{}]->[{}]".format(self.stage,
", ".join(str(k.interface) for k in self.kernels),
", ".join(str(mat) for mat in self.matrix_sequence))
", ".join(str(mat) for mat in self.matrix_sequence_quadrature_permuted))
mapper_method = "map_vectorized_sumfact_kernel"
......@@ -679,7 +679,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
#
@property
def function_name(self):
return "sfimpl_{}{}".format("_".join(str(m) for m in self.matrix_sequence),
return "sfimpl_{}{}".format("_".join(str(m) for m in self.matrix_sequence_quadrature_permuted),
self.interface.function_name_suffix)
@property
......@@ -688,22 +688,25 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
Any two sum factorization kernels having the same cache_key
are realized simultaneously!
"""
return (self.matrix_sequence, self.restriction, self.stage, self.buffer)
return (self.matrix_sequence_quadrature_permuted, self.restriction, self.stage, self.buffer)
#
# Deduce all data fields of normal sum factorization kernels from the underlying kernels
#
@property
def matrix_sequence(self):
# VectorizedSumfactKernel has no knowledge about the matrix_sequence
# ordered according to directions 0,1,... since it is constructed based
# on permuted matrix sequences.
raise RuntimeError("matrix_sequence should not be used on VectorizedSumfactKernel.")
@property
def matrix_sequence_quadrature_permuted(self):
return tuple(BasisTabulationMatrixArray(tuple(k.matrix_sequence_quadrature_permuted[i] for k in self.kernels),
width=self.vector_width,
)
for i in range(self.length))
@property
def matrix_sequence_quadrature_permuted(self):
return self.matrix_sequence
@property
def matrix_sequence_cost_permuted(self):
perm = sumfact_cost_permutation_strategy(self)
......@@ -801,11 +804,11 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
quad_inames = quadrature_inames(element)
index = []
if len(self.matrix_sequence) == local_dimension():
if len(self.matrix_sequence_quadrature_permuted) == local_dimension():
for d in range(local_dimension()):
addindex = prim.Variable(quad_inames[d])
if self.matrix_sequence[d].slice_size:
if self.matrix_sequence_quadrature_permuted[d].slice_size:
addindex = addindex // self.vertical_width
index.append(addindex)
......@@ -813,10 +816,10 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
# Traverse all the quadrature inames and map them to their correct direction
i = 0
for d in range(world_dimension()):
if self.matrix_sequence[d].face is None:
if self.matrix_sequence_quadrature_permuted[d].face is None:
addindex = prim.Variable(quad_inames[i])
if self.matrix_sequence[d].slice_size:
if self.matrix_sequence_quadrature_permuted[d].slice_size:
addindex = addindex // self.vertical_width
index.append(addindex)
......@@ -840,13 +843,13 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
sliced = 0
if len(sf.matrix_sequence_quadrature_permuted) == local_dimension():
for d in range(local_dimension()):
if self.matrix_sequence[d].slice_size:
if self.matrix_sequence_quadrature_permuted[d].slice_size:
sliced = prim.Variable(quad_inames[d])
else:
i = 0
for d in range(world_dimension()):
if self.matrix_sequence[d].face is None:
if self.matrix_sequence[d].slice_size:
if self.matrix_sequence_quadrature_permuted[d].face is None:
if self.matrix_sequence_quadrature_permuted[d].slice_size:
sliced = prim.Variable(quad_inames[i])
i = i + 1
......@@ -854,7 +857,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
@property
def quadrature_shape(self):
return tuple(mat.quadrature_size for mat in self.matrix_sequence) + (self.vector_width,)
return tuple(mat.quadrature_size for mat in self.matrix_sequence_quadrature_permuted) + (self.vector_width,)
def quadrature_index(self, sf, visitor, direct_index=None):
quad = self._quadrature_index(sf, visitor)
......@@ -872,7 +875,7 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
@property
def dof_shape(self):
return tuple(mat.basis_size for mat in self.matrix_sequence) + (self.vector_width,)
return tuple(mat.basis_size for mat in self.matrix_sequence_quadrature_permuted) + (self.vector_width,)
@property
def dof_dimtags(self):
......@@ -907,8 +910,8 @@ class VectorizedSumfactKernel(SumfactKernelBase, ImmutableRecord, prim.Variable)
""" The total number of bytes needed from RAM for the kernel
to be executed - neglecting the existence of caches of course
"""
dofs = product(mat.basis_size for mat in self.matrix_sequence)
matrices = sum(mat.memory_traffic for mat in set(self.matrix_sequence))
dofs = product(mat.basis_size for mat in self.matrix_sequence_quadrature_permuted)
matrices = sum(mat.memory_traffic for mat in set(self.matrix_sequence_quadrature_permuted))
fbytes = get_option("precision_bits") / 8
return (dofs + matrices) * fbytes
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment