Skip to content
Snippets Groups Projects
Commit eb035eff authored by Dominic Kempf's avatar Dominic Kempf
Browse files

Drop possibility of variation in transpose field for LargeAMatrix

It caused quite a bit of trouble with reuse of large matrices.
parent fca6404b
No related branches found
No related tags found
No related merge requests found
......@@ -46,9 +46,7 @@ class AMatrix(ImmutableRecord):
class LargeAMatrix(ImmutableRecord):
def __init__(self, rows, cols, transpose, derivative):
assert isinstance(transpose, tuple)
assert isinstance(derivative, tuple)
assert len(transpose) == len(derivative)
ImmutableRecord.__init__(self,
rows=rows,
cols=cols,
......@@ -227,15 +225,15 @@ def name_theta(transpose=False, derivative=False):
def name_large_theta(transpose=False, derivative=False):
ident = tuple("{}{}".format("d" if d else "", "T" if t else "") for t, d in zip(transpose, derivative))
name = "ThetaLarge_{}".format("_".join(ident))
if True in transpose:
ident = tuple("d" if d else "" for d in derivative)
name = "ThetaLarge{}_{}".format("T" if transpose else "", "_".join(ident))
if transpose:
shape = (basis_functions_per_direction(), quadrature_points_per_direction())
else:
shape = (quadrature_points_per_direction(), basis_functions_per_direction())
for i, (t, d) in enumerate(zip(transpose, derivative)):
define_theta(name, shape, t, d, additional_indices=(i,))
for i, d in enumerate(derivative):
define_theta(name, shape, transpose, d, additional_indices=(i,))
return loopy_class_member(name,
classtag="operator",
......
......@@ -60,17 +60,15 @@ def decide_stage_vectorization_strategy(sumfacts, stage):
assert len(set(tuple(sf.a_matrices[i].rows for sf in stage_sumfacts))) == 1
assert len(set(tuple(sf.a_matrices[i].cols for sf in stage_sumfacts))) == 1
# Collect the transpose/derivative information
transpose = [False] * 4
# Collect the derivative information
derivative = [False] * 4
for sf in stage_sumfacts:
transpose[position_mapping[sf]] = sf.a_matrices[i].transpose
derivative[position_mapping[sf]] = sf.a_matrices[i].derivative
from dune.perftool.sumfact.amatrix import LargeAMatrix
large = LargeAMatrix(rows=next(iter(stage_sumfacts)).a_matrices[i].rows,
cols=next(iter(stage_sumfacts)).a_matrices[i].cols,
transpose=tuple(transpose),
transpose=next(iter(stage_sumfacts)).a_matrices[i].transpose,
derivative=tuple(derivative),
)
large_a_matrices.append(large)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment