# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from unittest.mock import MagicMock

import pytest
import torch

from mmengine.dataset import (BaseDataset, ClassBalancedDataset, Compose,
                              ConcatDataset, RepeatDataset, force_full_init)
from mmengine.registry import DATASETS, TRANSFORMS
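
# The tests below rely on the dummy fixtures under ../data/ relative to this
# file (e.g. annotations/dummy_annotation.json, which supplies the metainfo
# asserted below and three data entries).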
def function_pipeline(data_info):
return data_info
@TRANSFORMS.register_module()
class CallableTransform:
def __call__(self, data_info):
return data_info
@TRANSFORMS.register_module()
class NotCallableTransform:
pass
@DATASETS.register_module()
class CustomDataset(BaseDataset):
pass
class TestBaseDataset:
dataset_type = BaseDataset
data_info = dict(
filename='test_img.jpg', height=604, width=640, sample_idx=0)
imgs = torch.rand((2, 3, 32, 32))
pipeline = MagicMock(return_value=dict(imgs=imgs))
METAINFO: dict = dict()
parse_data_info = MagicMock(return_value=data_info)
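    # `pipeline` and `parse_data_info` are mocked so that `dataset[idx]` and
    # `dataset.get_data_info(idx)` can be compared against fixed values.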
    def _init_dataset(self):
        self.dataset_type.METAINFO = self.METAINFO
        self.dataset_type.parse_data_info = self.parse_data_info
    def test_init(self):
        self._init_dataset()
        # test the instantiation of self.base_dataset
        dataset = self.dataset_type(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json')
assert dataset._fully_initialized
assert hasattr(dataset, 'data_list')
dataset = self.dataset_type(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img=None),
ann_file='annotations/dummy_annotation.json')
assert dataset._fully_initialized
assert hasattr(dataset, 'data_list')
        # test the instantiation of self.base_dataset with
        # `serialize_data=False`
        dataset = self.dataset_type(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json',
            serialize_data=False)
assert dataset._fully_initialized
assert hasattr(dataset, 'data_list')
assert not hasattr(dataset, 'data_address')
assert len(dataset) == 3
assert dataset.get_data_info(0) == self.data_info
        # test the instantiation of self.base_dataset with lazy init
        dataset = self.dataset_type(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json',
            lazy_init=True)
assert not dataset._fully_initialized
assert not dataset.data_list
        # test the instantiation of self.base_dataset when the ann_file does
        # not exist.
with pytest.raises(FileNotFoundError):
self.dataset_type(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='annotations/not_existed_annotation.json')
# Use the default value of ann_file, i.e., ''
with pytest.raises(FileNotFoundError):
self.dataset_type(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'))
# test the instantiation of self.base_dataset when the ann_file is
# wrong
with pytest.raises(ValueError):
self.dataset_type(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='annotations/annotation_wrong_keys.json')
with pytest.raises(TypeError):
self.dataset_type(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='annotations/annotation_wrong_format.json')
with pytest.raises(TypeError):
self.dataset_type(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img=['img']),
ann_file='annotations/annotation_wrong_format.json')
        # test the instantiation of self.base_dataset when `parse_data_info`
        # returns a list of data.
        self.dataset_type.parse_data_info = MagicMock(
            return_value=[self.data_info,
                          self.data_info.copy()])
        dataset = self.dataset_type(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json')
        dataset.pipeline = self.pipeline
        assert hasattr(dataset, 'data_list')
        assert len(dataset) == 6
        assert dataset[0] == dict(imgs=self.imgs)
assert dataset.get_data_info(0) == self.data_info
# test the instantiation of self.base_dataset when `parse_data_info`
# return unsupported data.
with pytest.raises(TypeError):
self.dataset_type.parse_data_info = MagicMock(return_value='xxx')
dataset = self.dataset_type(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='annotations/dummy_annotation.json')
        with pytest.raises(TypeError):
            self.dataset_type.parse_data_info = MagicMock(
                return_value=[self.data_info, 'xxx'])
self.dataset_type(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='annotations/dummy_annotation.json')
    def test_meta(self):
        self._init_dataset()
        # test dataset.metainfo with setting the metainfo from annotation file
        # as the metainfo of self.base_dataset.
        dataset = self.dataset_type(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json')
assert dataset.metainfo == dict(
dataset_type='test_dataset', task_name='test_task', empty_list=[])
        # test dataset.metainfo with setting METAINFO in self.base_dataset
        dataset_type = 'new_dataset'
        self.dataset_type.METAINFO = dict(
            dataset_type=dataset_type, classes=('dog', 'cat'))
        dataset = self.dataset_type(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json')
        assert dataset.metainfo == dict(
            dataset_type=dataset_type,
            task_name='test_task',
            classes=('dog', 'cat'),
            empty_list=[])
        # test dataset.metainfo with passing metainfo into self.base_dataset
        metainfo = dict(classes=('dog', ), task_name='new_task')
        dataset = self.dataset_type(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json',
            metainfo=metainfo)
assert self.dataset_type.METAINFO == dict(
dataset_type=dataset_type, classes=('dog', 'cat'))
        assert dataset.metainfo == dict(
            dataset_type=dataset_type,
            task_name='new_task',
            classes=('dog', ),
            empty_list=[])
# reset `base_dataset.METAINFO`, the `dataset.metainfo` should not
# change
self.dataset_type.METAINFO['classes'] = ('dog', 'cat', 'fish')
assert self.dataset_type.METAINFO == dict(
dataset_type=dataset_type, classes=('dog', 'cat', 'fish'))
assert dataset.metainfo == dict(
dataset_type=dataset_type,
task_name='new_task',
classes=('dog', ),
empty_list=[])
        # test dataset.metainfo with passing metainfo containing a file into
        # self.base_dataset
metainfo = dict(
classes=osp.join(
osp.dirname(__file__), '../data/meta/classes.txt'))
dataset = self.dataset_type(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='annotations/dummy_annotation.json',
metainfo=metainfo)
        assert dataset.metainfo == dict(
            dataset_type=dataset_type,
            task_name='test_task',
            classes=['dog'],
            empty_list=[])
# test dataset.metainfo with passing unsupported metainfo into
# self.base_dataset
with pytest.raises(TypeError):
metainfo = 'dog'
dataset = self.dataset_type(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='annotations/dummy_annotation.json',
metainfo=metainfo)
        # test dataset.metainfo with passing metainfo into self.base_dataset
        # and lazy_init is True
        metainfo = dict(classes=('dog', ))
        dataset = self.dataset_type(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json',
            metainfo=metainfo,
            lazy_init=True)
        # 'task_name' and 'empty_list' are not in dataset.metainfo because the
        # annotation file is not parsed when `lazy_init=True`.
assert dataset.metainfo == dict(
dataset_type=dataset_type, classes=('dog', ))
        # test whether self.base_dataset.METAINFO is changed when a customized
        # dataset inherits self.base_dataset and overrides `METAINFO`.

        # test reset METAINFO in ToyDataset.
        class ToyDataset(self.dataset_type):
            METAINFO = dict(xxx='xxx')

        assert ToyDataset.METAINFO == dict(xxx='xxx')
        assert self.dataset_type.METAINFO == dict(
            dataset_type=dataset_type, classes=('dog', 'cat', 'fish'))

        # test update METAINFO in ToyDataset.
        class ToyDataset(self.dataset_type):
            METAINFO = copy.deepcopy(self.dataset_type.METAINFO)
            METAINFO['classes'] = ('bird', )

        assert ToyDataset.METAINFO == dict(
            dataset_type=dataset_type, classes=('bird', ))
        assert self.dataset_type.METAINFO == dict(
            dataset_type=dataset_type, classes=('dog', 'cat', 'fish'))
@pytest.mark.parametrize('lazy_init', [True, False])
    def test_length(self, lazy_init):
        self._init_dataset()
        dataset = self.dataset_type(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json',
            lazy_init=lazy_init)
if not lazy_init:
assert dataset._fully_initialized
assert hasattr(dataset, 'data_list')
assert len(dataset) == 3
else:
# test `__len__()` when lazy_init is True
assert not dataset._fully_initialized
assert not dataset.data_list
assert len(dataset) == 3
assert hasattr(dataset, 'data_list')
def test_compose(self):
# test callable transform
transforms = [function_pipeline]
compose = Compose(transforms=transforms)
assert (self.imgs == compose(dict(img=self.imgs))['img']).all()
# test transform build from cfg_dict
transforms = [dict(type='CallableTransform')]
compose = Compose(transforms=transforms)
assert (self.imgs == compose(dict(img=self.imgs))['img']).all()
# test return None in advance
none_func = MagicMock(return_value=None)
transforms = [none_func, function_pipeline]
compose = Compose(transforms=transforms)
assert compose(dict(img=self.imgs)) is None
# test repr
repr_str = f'Compose(\n' \
f' {none_func}\n' \
f' {function_pipeline}\n' \
f')'
assert repr(compose) == repr_str
# non-callable transform will raise error
with pytest.raises(TypeError):
transforms = [dict(type='NotCallableTransform')]
Compose(transforms)
# transform must be callable or dict
with pytest.raises(TypeError):
Compose([1])
@pytest.mark.parametrize('lazy_init', [True, False])
    def test_getitem(self, lazy_init):
        self._init_dataset()
        dataset = self.dataset_type(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json',
            lazy_init=lazy_init)
        dataset.pipeline = self.pipeline
if not lazy_init:
assert dataset._fully_initialized
assert hasattr(dataset, 'data_list')
assert dataset[0] == dict(imgs=self.imgs)
else:
# Test `__getitem__()` when lazy_init is True
assert not dataset.data_list
# Call `full_init()` automatically
assert dataset[0] == dict(imgs=self.imgs)
assert dataset._fully_initialized
assert hasattr(dataset, 'data_list')
        # Test in non-test mode, where invalid samples are re-sampled.
        dataset.test_mode = False
# Test cannot get a valid image.
dataset.prepare_data = MagicMock(return_value=None)
with pytest.raises(Exception):
dataset[0]
# Test get valid image by `_rand_another`
def fake_prepare_data(idx):
if idx == 0:
return None
else:
return 1
dataset.prepare_data = fake_prepare_data
dataset[0]
dataset.test_mode = True
with pytest.raises(Exception):
dataset[0]
@pytest.mark.parametrize('lazy_init', [True, False])
def test_get_data_info(self, lazy_init):
self._init_dataset()
        dataset = self.dataset_type(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json',
            lazy_init=lazy_init)
if not lazy_init:
assert dataset._fully_initialized
assert hasattr(dataset, 'data_list')
assert dataset.get_data_info(0) == self.data_info
else:
# test `get_data_info()` when lazy_init is True
assert not dataset._fully_initialized
assert not dataset.data_list
# call `full_init()` automatically
assert dataset.get_data_info(0) == self.data_info
assert dataset._fully_initialized
assert hasattr(dataset, 'data_list')
def test_force_full_init(self):
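        # `force_full_init` is expected to call into the owner's full-init
        # machinery; invoking `foo()` on a class that provides none should
        # therefore raise AttributeError.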
with pytest.raises(AttributeError):
class ClassWithoutFullInit:
@force_full_init
def foo(self):
pass
class_without_full_init = ClassWithoutFullInit()
class_without_full_init.foo()
def test_full_init(self):
self._init_dataset()
        dataset = self.dataset_type(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json',
            lazy_init=True)
        dataset.pipeline = self.pipeline
# test `full_init()` when lazy_init is True
assert not dataset._fully_initialized
assert not dataset.data_list
# call `full_init()` manually
dataset.full_init()
assert dataset._fully_initialized
assert hasattr(dataset, 'data_list')
assert len(dataset) == 3
assert dataset[0] == dict(imgs=self.imgs)
assert dataset.get_data_info(0) == self.data_info
dataset = self.dataset_type(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='annotations/dummy_annotation.json',
lazy_init=False)
dataset.pipeline = self.pipeline
assert dataset._fully_initialized
assert hasattr(dataset, 'data_list')
assert len(dataset) == 3
assert dataset[0] == dict(imgs=self.imgs)
assert dataset.get_data_info(0) == self.data_info
# test the instantiation of self.base_dataset when passing indices
dataset = self.dataset_type(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img=None),
ann_file='annotations/dummy_annotation.json')
dataset_sliced = self.dataset_type(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img=None),
ann_file='annotations/dummy_annotation.json',
indices=1)
assert dataset_sliced[0] == dataset[0]
assert len(dataset_sliced) == 1
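    # `get_subset_` modifies the wrapped dataset in place (hence the deep
    # copies below), while `get_subset` returns a new sliced dataset.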
@pytest.mark.parametrize(
'lazy_init, serialize_data',
([True, False], [False, True], [True, True], [False, False]))
def test_get_subset_(self, lazy_init, serialize_data):
# Test positive int indices.
indices = 2
dataset = self.dataset_type(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img=None),
ann_file='annotations/dummy_annotation.json',
lazy_init=lazy_init,
serialize_data=serialize_data)
dataset_copy = copy.deepcopy(dataset)
dataset_copy.get_subset_(indices)
assert len(dataset_copy) == 2
for i in range(len(dataset_copy)):
ori_data = dataset[i]
assert dataset_copy[i] == ori_data
# Test negative int indices.
indices = -2
dataset_copy = copy.deepcopy(dataset)
dataset_copy.get_subset_(indices)
assert len(dataset_copy) == 2
for i in range(len(dataset_copy)):
ori_data = dataset[i + 1]
ori_data['sample_idx'] = i
assert dataset_copy[i] == ori_data
# If indices is 0, return empty dataset.
dataset_copy = copy.deepcopy(dataset)
dataset_copy.get_subset_(0)
assert len(dataset_copy) == 0
# Test list indices with positive element.
indices = [1]
dataset_copy = copy.deepcopy(dataset)
ori_data = dataset[1]
ori_data['sample_idx'] = 0
dataset_copy.get_subset_(indices)
assert len(dataset_copy) == 1
assert dataset_copy[0] == ori_data
# Test list indices with negative element.
indices = [-1]
dataset_copy = copy.deepcopy(dataset)
ori_data = dataset[2]
ori_data['sample_idx'] = 0
dataset_copy.get_subset_(indices)
assert len(dataset_copy) == 1
assert dataset_copy[0] == ori_data
# Test empty list.
indices = []
dataset_copy = copy.deepcopy(dataset)
dataset_copy.get_subset_(indices)
assert len(dataset_copy) == 0
# Test list with multiple positive indices.
indices = [0, 1, 2]
dataset_copy = copy.deepcopy(dataset)
dataset_copy.get_subset_(indices)
for i in range(len(dataset_copy)):
ori_data = dataset[i]
ori_data['sample_idx'] = i
assert dataset_copy[i] == ori_data
# Test list with multiple negative indices.
indices = [-1, -2, 0]
dataset_copy = copy.deepcopy(dataset)
dataset_copy.get_subset_(indices)
for i in range(len(dataset_copy)):
ori_data = dataset[len(dataset) - i - 1]
ori_data['sample_idx'] = i
assert dataset_copy[i] == ori_data
with pytest.raises(TypeError):
dataset.get_subset_(dict())
@pytest.mark.parametrize(
'lazy_init, serialize_data',
([True, False], [False, True], [True, True], [False, False]))
def test_get_subset(self, lazy_init, serialize_data):
# Test positive indices.
indices = 2
dataset = self.dataset_type(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img=None),
ann_file='annotations/dummy_annotation.json',
lazy_init=lazy_init,
serialize_data=serialize_data)
dataset_sliced = dataset.get_subset(indices)
assert len(dataset_sliced) == 2
assert dataset_sliced[0] == dataset[0]
for i in range(len(dataset_sliced)):
assert dataset_sliced[i] == dataset[i]
# Test negative indices.
indices = -2
dataset_sliced = dataset.get_subset(indices)
assert len(dataset_sliced) == 2
for i in range(len(dataset_sliced)):
ori_data = dataset[i + 1]
ori_data['sample_idx'] = i
assert dataset_sliced[i] == ori_data
# If indices is 0 or empty list, return empty dataset.
assert len(dataset.get_subset(0)) == 0
assert len(dataset.get_subset([])) == 0
# test list indices.
indices = [1]
dataset_sliced = dataset.get_subset(indices)
ori_data = dataset[1]
ori_data['sample_idx'] = 0
assert len(dataset_sliced) == 1
assert dataset_sliced[0] == ori_data
# Test list with multiple positive index.
indices = [0, 1, 2]
dataset_sliced = dataset.get_subset(indices)
for i in range(len(dataset_sliced)):
ori_data = dataset[i]
ori_data['sample_idx'] = i
assert dataset_sliced[i] == ori_data
# Test list with multiple negative index.
indices = [-1, -2, 0]
dataset_sliced = dataset.get_subset(indices)
for i in range(len(dataset_sliced)):
ori_data = dataset[len(dataset) - i - 1]
ori_data['sample_idx'] = i
assert dataset_sliced[i] == ori_data
def test_rand_another(self):
# test the instantiation of self.base_dataset when passing num_samples
dataset = self.dataset_type(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img=None),
ann_file='annotations/dummy_annotation.json',
indices=1)
assert dataset._rand_another() >= 0
assert dataset._rand_another() < len(dataset)
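
# ConcatDataset tests: two BaseDataset instances with mocked `parse_data_info`
# and pipelines are concatenated, so indices 0-2 resolve to dataset_a and
# indices 3-5 to dataset_b.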
class TestConcatDataset:

    def setup_method(self):
        dataset = BaseDataset
        # create dataset_a
        data_info = dict(filename='test_img.jpg', height=604, width=640)
        dataset.parse_data_info = MagicMock(return_value=data_info)
        imgs = torch.rand((2, 3, 32, 32))
        self.dataset_a = dataset(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='annotations/dummy_annotation.json')
self.dataset_a.pipeline = MagicMock(return_value=dict(imgs=imgs))
# create dataset_b
data_info = dict(filename='gray.jpg', height=288, width=512)
dataset.parse_data_info = MagicMock(return_value=data_info)
imgs = torch.rand((2, 3, 32, 32))
self.dataset_b = dataset(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='annotations/dummy_annotation.json')
self.dataset_b.pipeline = MagicMock(return_value=dict(imgs=imgs))
# test init
self.cat_datasets = ConcatDataset(
datasets=[self.dataset_a, self.dataset_b])
def test_init(self):
# Test build dataset from cfg.
dataset_cfg_b = dict(
type=CustomDataset,
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='annotations/dummy_annotation.json')
cat_datasets = ConcatDataset(datasets=[self.dataset_a, dataset_cfg_b])
cat_datasets.datasets[1].pipeline = self.dataset_b.pipeline
assert len(cat_datasets) == len(self.cat_datasets)
for i in range(len(cat_datasets)):
assert (cat_datasets.get_data_info(i) ==
self.cat_datasets.get_data_info(i))
assert (cat_datasets[i] == self.cat_datasets[i])
with pytest.raises(TypeError):
ConcatDataset(datasets=[0])
def test_full_init(self):
# test init with lazy_init=True
self.cat_datasets.full_init()
assert len(self.cat_datasets) == 6
self.cat_datasets.full_init()
self.cat_datasets._fully_initialized = False
self.cat_datasets[1]
assert len(self.cat_datasets) == 6
with pytest.raises(NotImplementedError):
self.cat_datasets.get_subset_(1)
with pytest.raises(NotImplementedError):
self.cat_datasets.get_subset(1)
# Different meta information will raise error.
with pytest.raises(ValueError):
dataset_b = BaseDataset(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='annotations/dummy_annotation.json',
metainfo=dict(classes=('cat')))
ConcatDataset(datasets=[self.dataset_a, dataset_b])
def test_metainfo(self):
assert self.cat_datasets.metainfo == self.dataset_a.metainfo
def test_length(self):
assert len(self.cat_datasets) == (
len(self.dataset_a) + len(self.dataset_b))
def test_getitem(self):
assert (
self.cat_datasets[0]['imgs'] == self.dataset_a[0]['imgs']).all()
assert (self.cat_datasets[0]['imgs'] !=
self.dataset_b[0]['imgs']).all()
assert (
self.cat_datasets[-1]['imgs'] == self.dataset_b[-1]['imgs']).all()
assert (self.cat_datasets[-1]['imgs'] !=
self.dataset_a[-1]['imgs']).all()
def test_get_data_info(self):
assert self.cat_datasets.get_data_info(
0) == self.dataset_a.get_data_info(0)
assert self.cat_datasets.get_data_info(
0) != self.dataset_b.get_data_info(0)
assert self.cat_datasets.get_data_info(
-1) == self.dataset_b.get_data_info(-1)
assert self.cat_datasets.get_data_info(
-1) != self.dataset_a.get_data_info(-1)
def test_get_ori_dataset_idx(self):
assert self.cat_datasets._get_ori_dataset_idx(3) == (
1, 3 - len(self.dataset_a))
assert self.cat_datasets._get_ori_dataset_idx(-1) == (
1, len(self.dataset_b) - 1)
with pytest.raises(ValueError):
assert self.cat_datasets._get_ori_dataset_idx(-10)
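
# RepeatDataset repeats the wrapped dataset `times` times, so index
# `len(dataset) * i` maps back to sample 0 of the wrapped dataset.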
class TestRepeatDataset:

    def setup_method(self):
        dataset = BaseDataset
data_info = dict(filename='test_img.jpg', height=604, width=640)
dataset.parse_data_info = MagicMock(return_value=data_info)
imgs = torch.rand((2, 3, 32, 32))
self.dataset = dataset(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='annotations/dummy_annotation.json')
self.dataset.pipeline = MagicMock(return_value=dict(imgs=imgs))
self.repeat_times = 5
# test init
self.repeat_datasets = RepeatDataset(
dataset=self.dataset, times=self.repeat_times)
def test_init(self):
# Test build dataset from cfg.
dataset_cfg = dict(
type=CustomDataset,
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='annotations/dummy_annotation.json')
repeat_dataset = RepeatDataset(
dataset=dataset_cfg, times=self.repeat_times)
repeat_dataset.dataset.pipeline = self.dataset.pipeline
assert len(repeat_dataset) == len(self.repeat_datasets)
for i in range(len(repeat_dataset)):
assert (repeat_dataset.get_data_info(i) ==
self.repeat_datasets.get_data_info(i))
assert (repeat_dataset[i] == self.repeat_datasets[i])
with pytest.raises(TypeError):
RepeatDataset(dataset=[0], times=5)
def test_full_init(self):
self.repeat_datasets.full_init()
assert len(
self.repeat_datasets) == self.repeat_times * len(self.dataset)
self.repeat_datasets.full_init()
self.repeat_datasets._fully_initialized = False
self.repeat_datasets[1]
assert len(self.repeat_datasets) == \
self.repeat_times * len(self.dataset)
with pytest.raises(NotImplementedError):
self.repeat_datasets.get_subset_(1)
with pytest.raises(NotImplementedError):
self.repeat_datasets.get_subset(1)
def test_metainfo(self):
assert self.repeat_datasets.metainfo == self.dataset.metainfo
def test_length(self):
assert len(
self.repeat_datasets) == len(self.dataset) * self.repeat_times
def test_getitem(self):
for i in range(self.repeat_times):
assert self.repeat_datasets[len(self.dataset) *
i] == self.dataset[0]
def test_get_data_info(self):
for i in range(self.repeat_times):
assert self.repeat_datasets.get_data_info(
len(self.dataset) * i) == self.dataset.get_data_info(0)
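
# ClassBalancedDataset repeats samples based on category frequency (controlled
# by `oversample_thr`); the tests pin `repeat_indices` to a fixed list so the
# index mapping can be checked deterministically.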
class TestClassBalancedDataset:

    def setup_method(self):
        dataset = BaseDataset
data_info = dict(filename='test_img.jpg', height=604, width=640)
dataset.parse_data_info = MagicMock(return_value=data_info)
imgs = torch.rand((2, 3, 32, 32))
dataset.get_cat_ids = MagicMock(return_value=[0])
self.dataset = dataset(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='annotations/dummy_annotation.json')
self.dataset.pipeline = MagicMock(return_value=dict(imgs=imgs))
self.repeat_indices = [0, 0, 1, 1, 1]
# test init
self.cls_banlanced_datasets = ClassBalancedDataset(
dataset=self.dataset, oversample_thr=1e-3)
self.cls_banlanced_datasets.repeat_indices = self.repeat_indices
def test_init(self):
# Test build dataset from cfg.
dataset_cfg = dict(
type=CustomDataset,
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='annotations/dummy_annotation.json')
cls_banlanced_datasets = ClassBalancedDataset(
dataset=dataset_cfg, oversample_thr=1e-3)
cls_banlanced_datasets.repeat_indices = self.repeat_indices
cls_banlanced_datasets.dataset.pipeline = self.dataset.pipeline
assert len(cls_banlanced_datasets) == len(self.cls_banlanced_datasets)
for i in range(len(cls_banlanced_datasets)):
assert (cls_banlanced_datasets.get_data_info(i) ==
self.cls_banlanced_datasets.get_data_info(i))
assert (
cls_banlanced_datasets[i] == self.cls_banlanced_datasets[i])
with pytest.raises(TypeError):
            ClassBalancedDataset(dataset=[0], oversample_thr=1e-3)
    def test_full_init(self):
        self.cls_banlanced_datasets.full_init()
self.cls_banlanced_datasets.repeat_indices = self.repeat_indices
assert len(self.cls_banlanced_datasets) == len(self.repeat_indices)
self.cls_banlanced_datasets._fully_initialized = False
self.cls_banlanced_datasets.repeat_indices = self.repeat_indices
assert len(self.cls_banlanced_datasets) != len(self.repeat_indices)
with pytest.raises(NotImplementedError):
self.cls_banlanced_datasets.get_subset_(1)
with pytest.raises(NotImplementedError):
self.cls_banlanced_datasets.get_subset(1)
def test_metainfo(self):
assert self.cls_banlanced_datasets.metainfo == self.dataset.metainfo
def test_length(self):
assert len(self.cls_banlanced_datasets) == len(self.repeat_indices)
def test_getitem(self):
for i in range(len(self.repeat_indices)):
assert self.cls_banlanced_datasets[i] == self.dataset[
self.repeat_indices[i]]
def test_get_data_info(self):
for i in range(len(self.repeat_indices)):
assert self.cls_banlanced_datasets.get_data_info(
i) == self.dataset.get_data_info(self.repeat_indices[i])
def test_get_cat_ids(self):
for i in range(len(self.repeat_indices)):
assert self.cls_banlanced_datasets.get_cat_ids(
i) == self.dataset.get_cat_ids(self.repeat_indices[i])