diff --git a/tests/test_evaluator/test_base_evaluator.py b/tests/test_evaluator/test_base_evaluator.py index bed31b1f21ae825ca0f8558075554a39375a5477..042d2fb8a0a1298f7dec93e59c276718349f7368 100644 --- a/tests/test_evaluator/test_base_evaluator.py +++ b/tests/test_evaluator/test_base_evaluator.py @@ -79,10 +79,9 @@ def generate_test_results(size, batch_size, pred, label): bs_residual = size % batch_size for i in range(num_batch): bs = bs_residual if i == num_batch - 1 else batch_size - data_batch = [(np.zeros( - (3, 10, 10)), BaseDataElement(data={'label': label})) + data_batch = [(np.zeros((3, 10, 10)), BaseDataElement(label=label)) for _ in range(bs)] - predictions = [BaseDataElement(data={'pred': pred}) for _ in range(bs)] + predictions = [BaseDataElement(pred=pred) for _ in range(bs)] yield (data_batch, predictions)