Skip to content
Snippets Groups Projects
Commit dec5228d authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Correctly generate names for combined theta matrices

fixes a nasty bug, where face information was omitted
parent 989e0163
No related branches found
No related tags found
No related merge requests found
...@@ -62,14 +62,27 @@ class BasisTabulationMatrix(BasisTabulationMatrixBase, ImmutableRecord): ...@@ -62,14 +62,27 @@ class BasisTabulationMatrix(BasisTabulationMatrixBase, ImmutableRecord):
slice_index=slice_index, slice_index=slice_index,
) )
@property
def _shortname(self):
infos = ["d{}".format(self.basis_size),
"q{}".format(self.quadrature_size)]
if self.transpose:
infos.append("T")
if self.derivative:
infos.append("dx")
if self.face is not None:
infos.append("f{}".format(self.face))
if self.slice_size is not None:
infos.append("s{}".format(self.slice_index))
return "".join(infos)
def __str__(self): def __str__(self):
return "{}{}A{}{}{}" \ return "Theta_{}".format(self._shortname)
.format("face{}_".format(self.face) if self.face is not None else "",
"d" if self.derivative else "",
self.basis_size,
"T" if self.transpose else "",
"_slice{}".format(self.slice_index) if self.slice_size is not None else "",
)
@property @property
def rows(self): def rows(self):
...@@ -96,14 +109,7 @@ class BasisTabulationMatrix(BasisTabulationMatrixBase, ImmutableRecord): ...@@ -96,14 +109,7 @@ class BasisTabulationMatrix(BasisTabulationMatrixBase, ImmutableRecord):
return size return size
def pymbolic(self, indices): def pymbolic(self, indices):
name = "{}{}Theta{}{}_qp{}_dof{}" \ name = str(self)
.format("face{}_".format(self.face) if self.face is not None else "",
"d" if self.derivative else "",
"T" if self.transpose else "",
"_slice{}".format(self.slice_index) if self.slice_size is not None else "",
self.quadrature_size,
self.basis_size,
)
define_theta(name, self) define_theta(name, self)
return prim.Subscript(prim.Variable(name), indices) return prim.Subscript(prim.Variable(name), indices)
...@@ -140,11 +146,7 @@ class BasisTabulationMatrixArray(BasisTabulationMatrixBase): ...@@ -140,11 +146,7 @@ class BasisTabulationMatrixArray(BasisTabulationMatrixBase):
self.width = width self.width = width
def __str__(self): def __str__(self):
abbrevs = tuple("{}A{}{}".format("d" if t.derivative else "", return "_".join((t._shortname for t in self.tabs))
self.basis_size,
"s{}".format(t.slice_index) if t.slice_size is not None else "")
for t in self.tabs)
return "_".join(abbrevs)
@property @property
def quadrature_size(self): def quadrature_size(self):
...@@ -196,15 +198,8 @@ class BasisTabulationMatrixArray(BasisTabulationMatrixBase): ...@@ -196,15 +198,8 @@ class BasisTabulationMatrixArray(BasisTabulationMatrixBase):
theta = self.tabs[0].pymbolic(indices[:-1]) theta = self.tabs[0].pymbolic(indices[:-1])
return prim.Call(ExplicitVCLCast(dtype_floatingpoint(), vector_width=get_vcl_type_size(dtype_floatingpoint())), (theta,)) return prim.Call(ExplicitVCLCast(dtype_floatingpoint(), vector_width=get_vcl_type_size(dtype_floatingpoint())), (theta,))
abbrevs = tuple("{}x{}".format("d" if t.derivative else "", name = str(self)
"s{}".format(t.slice_index) if t.slice_size is not None else "")
for t in self.tabs)
name = "ThetaLarge{}{}_{}_qp{}_dof{}".format("face{}_".format(self.face) if self.face is not None else "",
"T" if self.transpose else "",
"_".join(abbrevs),
self.tabs[0].quadrature_size,
self.tabs[0].basis_size,
)
for i, tab in enumerate(self.tabs): for i, tab in enumerate(self.tabs):
define_theta(name, tab, additional_indices=(i,), width=self.width) define_theta(name, tab, additional_indices=(i,), width=self.width)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment