diff --git a/python/dune/perftool/sumfact/amatrix.py b/python/dune/perftool/sumfact/amatrix.py index 54303ecedeb42608d2054b9c91264ce3e973bbd7..2950782a57351d5391e6aabf83ae14c0e25b4cc0 100644 --- a/python/dune/perftool/sumfact/amatrix.py +++ b/python/dune/perftool/sumfact/amatrix.py @@ -46,9 +46,7 @@ class AMatrix(ImmutableRecord): class LargeAMatrix(ImmutableRecord): def __init__(self, rows, cols, transpose, derivative): - assert isinstance(transpose, tuple) assert isinstance(derivative, tuple) - assert len(transpose) == len(derivative) ImmutableRecord.__init__(self, rows=rows, cols=cols, @@ -227,15 +225,15 @@ def name_theta(transpose=False, derivative=False): def name_large_theta(transpose=False, derivative=False): - ident = tuple("{}{}".format("d" if d else "", "T" if t else "") for t, d in zip(transpose, derivative)) - name = "ThetaLarge_{}".format("_".join(ident)) - if True in transpose: + ident = tuple("d" if d else "" for d in derivative) + name = "ThetaLarge{}_{}".format("T" if transpose else "", "_".join(ident)) + if transpose: shape = (basis_functions_per_direction(), quadrature_points_per_direction()) else: shape = (quadrature_points_per_direction(), basis_functions_per_direction()) - for i, (t, d) in enumerate(zip(transpose, derivative)): - define_theta(name, shape, t, d, additional_indices=(i,)) + for i, d in enumerate(derivative): + define_theta(name, shape, transpose, d, additional_indices=(i,)) return loopy_class_member(name, classtag="operator", diff --git a/python/dune/perftool/sumfact/vectorization.py b/python/dune/perftool/sumfact/vectorization.py index 9155fb272bfe1ed8c3b14c8a8315a5db9893e1bc..3747a930044f67b55a4814cc441e555c44f885de 100644 --- a/python/dune/perftool/sumfact/vectorization.py +++ b/python/dune/perftool/sumfact/vectorization.py @@ -60,17 +60,15 @@ def decide_stage_vectorization_strategy(sumfacts, stage): assert len(set(tuple(sf.a_matrices[i].rows for sf in stage_sumfacts))) == 1 assert len(set(tuple(sf.a_matrices[i].cols for sf in stage_sumfacts))) == 1 - # Collect the transpose/derivative information - transpose = [False] * 4 + # Collect the derivative information derivative = [False] * 4 for sf in stage_sumfacts: - transpose[position_mapping[sf]] = sf.a_matrices[i].transpose derivative[position_mapping[sf]] = sf.a_matrices[i].derivative from dune.perftool.sumfact.amatrix import LargeAMatrix large = LargeAMatrix(rows=next(iter(stage_sumfacts)).a_matrices[i].rows, cols=next(iter(stage_sumfacts)).a_matrices[i].cols, - transpose=tuple(transpose), + transpose=next(iter(stage_sumfacts)).a_matrices[i].transpose, derivative=tuple(derivative), ) large_a_matrices.append(large)