diff --git a/mmengine/dataset/base_dataset.py b/mmengine/dataset/base_dataset.py index 9e85f9023061817eb6787cf3d37f308a56abef0c..6f3b75174a7058dea715faa8befbdb8670d113c3 100644 --- a/mmengine/dataset/base_dataset.py +++ b/mmengine/dataset/base_dataset.py @@ -152,10 +152,10 @@ class BaseDataset(Dataset): ann_file (str): Annotation file path. Defaults to ''. metainfo (dict, optional): Meta information for dataset, such as class information. Defaults to None. - data_root (str, optional): The root directory for ``data_prefix`` and - ``ann_file``. Defaults to None. - data_prefix (dict, optional): Prefix for training data. Defaults to - dict(img_path=None, seg_path=None). + data_root (str): The root directory for ``data_prefix`` and + ``ann_file``. Defaults to ''. + data_prefix (dict): Prefix for training data. Defaults to + dict(img_path=''). filter_cfg (dict, optional): Config for filter data. Defaults to None. indices (int or Sequence[int], optional): Support using first few data in annotation file to facilitate training/testing on a smaller @@ -210,8 +210,8 @@ class BaseDataset(Dataset): def __init__(self, ann_file: str = '', metainfo: Optional[dict] = None, - data_root: Optional[str] = None, - data_prefix: dict = dict(img_path=None, seg_path=None), + data_root: str = '', + data_prefix: dict = dict(img_path=''), filter_cfg: Optional[dict] = None, indices: Optional[Union[int, Sequence[int]]] = None, serialize_data: bool = True, @@ -235,8 +235,7 @@ class BaseDataset(Dataset): self._metainfo = self._load_metainfo(copy.deepcopy(metainfo)) # Join paths. - if self.data_root is not None: - self._join_prefix() + self._join_prefix() # Build pipeline. self.pipeline = Compose(pipeline) @@ -534,16 +533,14 @@ class BaseDataset(Dataset): # Automatically join data directory with `self.root` if path value in # `self.data_prefix` is not an absolute path. for data_key, prefix in self.data_prefix.items(): - if prefix is None: - self.data_prefix[data_key] = self.data_root - elif isinstance(prefix, str): + if isinstance(prefix, str): if not is_abs(prefix): self.data_prefix[data_key] = osp.join( self.data_root, prefix) else: self.data_prefix[data_key] = prefix else: - raise TypeError('prefix should be a string or None, but got ' + raise TypeError('prefix should be a string, but got ' f'{type(prefix)}') @force_full_init diff --git a/tests/test_data/test_base_dataset.py b/tests/test_data/test_base_dataset.py index 7c575e2af80b7ce8ca2ca41212a9ffe1f07c6aea..12e637d5caab2ae9c8e64e4f48649ed1904438ec 100644 --- a/tests/test_data/test_base_dataset.py +++ b/tests/test_data/test_base_dataset.py @@ -51,14 +51,14 @@ class TestBaseDataset: # test the instantiation of self.base_dataset dataset = BaseDataset( data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img='imgs'), + data_prefix=dict(img_path='imgs'), ann_file='annotations/dummy_annotation.json') assert dataset._fully_initialized assert hasattr(dataset, 'data_list') assert hasattr(dataset, 'data_address') dataset = BaseDataset( data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img=None), + data_prefix=dict(img_path=''), ann_file='annotations/dummy_annotation.json') assert dataset._fully_initialized assert hasattr(dataset, 'data_list') @@ -68,7 +68,7 @@ class TestBaseDataset: # `serialize_data=False` dataset = BaseDataset( data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img='imgs'), + data_prefix=dict(img_path='imgs'), ann_file='annotations/dummy_annotation.json', serialize_data=False) assert dataset._fully_initialized @@ -80,7 +80,7 @@ class TestBaseDataset: # test the instantiation of self.base_dataset with lazy init dataset = BaseDataset( data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img='imgs'), + data_prefix=dict(img_path='imgs'), ann_file='annotations/dummy_annotation.json', lazy_init=True) assert not dataset._fully_initialized @@ -91,30 +91,30 @@ class TestBaseDataset: with pytest.raises(FileNotFoundError): BaseDataset( data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img='imgs'), + data_prefix=dict(img_path='imgs'), ann_file='annotations/not_existed_annotation.json') # Use the default value of ann_file, i.e., '' with pytest.raises(TypeError): BaseDataset( data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img='imgs')) + data_prefix=dict(img_path='imgs')) # test the instantiation of self.base_dataset when the ann_file is # wrong with pytest.raises(ValueError): BaseDataset( data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img='imgs'), + data_prefix=dict(img_path='imgs'), ann_file='annotations/annotation_wrong_keys.json') with pytest.raises(TypeError): BaseDataset( data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img='imgs'), + data_prefix=dict(img_path='imgs'), ann_file='annotations/annotation_wrong_format.json') with pytest.raises(TypeError): BaseDataset( data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img=['img']), + data_prefix=dict(img_path=['img']), ann_file='annotations/annotation_wrong_format.json') # test the instantiation of self.base_dataset when `parse_data_info` @@ -124,7 +124,7 @@ class TestBaseDataset: self.data_info.copy()]) dataset = BaseDataset( data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img='imgs'), + data_prefix=dict(img_path='imgs'), ann_file='annotations/dummy_annotation.json') dataset.pipeline = self.pipeline assert dataset._fully_initialized @@ -140,20 +140,20 @@ class TestBaseDataset: BaseDataset.parse_data_info = MagicMock(return_value='xxx') dataset = BaseDataset( data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img='imgs'), + data_prefix=dict(img_path='imgs'), ann_file='annotations/dummy_annotation.json') with pytest.raises(TypeError): BaseDataset.parse_data_info = MagicMock( return_value=[self.data_info, 'xxx']) BaseDataset( data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img='imgs'), + data_prefix=dict(img_path='imgs'), ann_file='annotations/dummy_annotation.json') # test the instantiation of self.base_dataset without `ann_file` BaseDataset.parse_data_info = self.ori_parse_data_info dataset = BaseDataset( data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img='imgs'), + data_prefix=dict(img_path='imgs'), ann_file='', serialize_data=False, lazy_init=True) @@ -164,7 +164,7 @@ class TestBaseDataset: # as the metainfo of self.base_dataset. dataset = BaseDataset( data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img='imgs'), + data_prefix=dict(img_path='imgs'), ann_file='annotations/dummy_annotation.json') assert dataset.metainfo == dict( @@ -177,7 +177,7 @@ class TestBaseDataset: dataset = BaseDataset( data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img='imgs'), + data_prefix=dict(img_path='imgs'), ann_file='annotations/dummy_annotation.json') assert dataset.metainfo == dict( dataset_type=dataset_type, @@ -189,7 +189,7 @@ class TestBaseDataset: metainfo = dict(classes=('dog', ), task_name='new_task') dataset = BaseDataset( data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img='imgs'), + data_prefix=dict(img_path='imgs'), ann_file='annotations/dummy_annotation.json', metainfo=metainfo) assert BaseDataset.METAINFO == dict( @@ -217,7 +217,7 @@ class TestBaseDataset: osp.dirname(__file__), '../data/meta/classes.txt')) dataset = BaseDataset( data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img='imgs'), + data_prefix=dict(img_path='imgs'), ann_file='annotations/dummy_annotation.json', metainfo=metainfo) assert dataset.metainfo == dict( @@ -232,7 +232,7 @@ class TestBaseDataset: metainfo = 'dog' dataset = BaseDataset( data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img='imgs'), + data_prefix=dict(img_path='imgs'), ann_file='annotations/dummy_annotation.json', metainfo=metainfo) @@ -241,7 +241,7 @@ class TestBaseDataset: metainfo = dict(classes=('dog', )) dataset = BaseDataset( data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img='imgs'), + data_prefix=dict(img_path='imgs'), ann_file='annotations/dummy_annotation.json', metainfo=metainfo, lazy_init=True) @@ -273,7 +273,7 @@ class TestBaseDataset: def test_length(self, lazy_init): dataset = BaseDataset( data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img='imgs'), + data_prefix=dict(img_path='imgs'), ann_file='annotations/dummy_annotation.json', lazy_init=lazy_init) if not lazy_init: @@ -322,7 +322,7 @@ class TestBaseDataset: def test_getitem(self, lazy_init): dataset = BaseDataset( data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img='imgs'), + data_prefix=dict(img_path='imgs'), ann_file='annotations/dummy_annotation.json', lazy_init=lazy_init) dataset.pipeline = self.pipeline @@ -364,7 +364,7 @@ class TestBaseDataset: def test_get_data_info(self, lazy_init): dataset = BaseDataset( data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img='imgs'), + data_prefix=dict(img_path='imgs'), ann_file='annotations/dummy_annotation.json', lazy_init=lazy_init) @@ -406,7 +406,7 @@ class TestBaseDataset: def test_full_init(self): dataset = BaseDataset( data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img='imgs'), + data_prefix=dict(img_path='imgs'), ann_file='annotations/dummy_annotation.json', lazy_init=True) dataset.pipeline = self.pipeline @@ -423,7 +423,7 @@ class TestBaseDataset: dataset = BaseDataset( data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img='imgs'), + data_prefix=dict(img_path='imgs'), ann_file='annotations/dummy_annotation.json', lazy_init=False) @@ -437,11 +437,11 @@ class TestBaseDataset: # test the instantiation of self.base_dataset when passing indices dataset = BaseDataset( data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img=None), + data_prefix=dict(img_path=''), ann_file='annotations/dummy_annotation.json') dataset_sliced = BaseDataset( data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img=None), + data_prefix=dict(img_path=''), ann_file='annotations/dummy_annotation.json', indices=1) assert dataset_sliced[0] == dataset[0] @@ -455,7 +455,7 @@ class TestBaseDataset: indices = 2 dataset = BaseDataset( data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img=None), + data_prefix=dict(img_path=''), ann_file='annotations/dummy_annotation.json', lazy_init=lazy_init, serialize_data=serialize_data) @@ -533,7 +533,7 @@ class TestBaseDataset: indices = 2 dataset = BaseDataset( data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img=None), + data_prefix=dict(img_path=''), ann_file='annotations/dummy_annotation.json', lazy_init=lazy_init, serialize_data=serialize_data) @@ -579,7 +579,7 @@ class TestBaseDataset: # test the instantiation of self.base_dataset when passing num_samples dataset = BaseDataset( data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img=None), + data_prefix=dict(img_path=''), ann_file='annotations/dummy_annotation.json', indices=1) assert dataset._rand_another() >= 0 @@ -598,7 +598,7 @@ class TestConcatDataset: self.dataset_a = dataset( data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img='imgs'), + data_prefix=dict(img_path='imgs'), ann_file='annotations/dummy_annotation.json') self.dataset_a.pipeline = MagicMock(return_value=dict(imgs=imgs)) @@ -608,7 +608,7 @@ class TestConcatDataset: imgs = torch.rand((2, 3, 32, 32)) self.dataset_b = dataset( data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img='imgs'), + data_prefix=dict(img_path='imgs'), ann_file='annotations/dummy_annotation.json') self.dataset_b.pipeline = MagicMock(return_value=dict(imgs=imgs)) # test init @@ -620,7 +620,7 @@ class TestConcatDataset: dataset_cfg_b = dict( type=CustomDataset, data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img='imgs'), + data_prefix=dict(img_path='imgs'), ann_file='annotations/dummy_annotation.json') cat_datasets = ConcatDataset(datasets=[self.dataset_a, dataset_cfg_b]) cat_datasets.datasets[1].pipeline = self.dataset_b.pipeline @@ -651,7 +651,7 @@ class TestConcatDataset: with pytest.raises(ValueError): dataset_b = BaseDataset( data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img='imgs'), + data_prefix=dict(img_path='imgs'), ann_file='annotations/dummy_annotation.json', metainfo=dict(classes=('cat'))) ConcatDataset(datasets=[self.dataset_a, dataset_b]) @@ -703,7 +703,7 @@ class TestRepeatDataset: imgs = torch.rand((2, 3, 32, 32)) self.dataset = dataset( data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img='imgs'), + data_prefix=dict(img_path='imgs'), ann_file='annotations/dummy_annotation.json') self.dataset.pipeline = MagicMock(return_value=dict(imgs=imgs)) @@ -717,7 +717,7 @@ class TestRepeatDataset: dataset_cfg = dict( type=CustomDataset, data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img='imgs'), + data_prefix=dict(img_path='imgs'), ann_file='annotations/dummy_annotation.json') repeat_dataset = RepeatDataset( dataset=dataset_cfg, times=self.repeat_times) @@ -775,7 +775,7 @@ class TestClassBalancedDataset: dataset.get_cat_ids = MagicMock(return_value=[0]) self.dataset = dataset( data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img='imgs'), + data_prefix=dict(img_path='imgs'), ann_file='annotations/dummy_annotation.json') self.dataset.pipeline = MagicMock(return_value=dict(imgs=imgs)) @@ -790,7 +790,7 @@ class TestClassBalancedDataset: dataset_cfg = dict( type=CustomDataset, data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img='imgs'), + data_prefix=dict(img_path='imgs'), ann_file='annotations/dummy_annotation.json') cls_banlanced_datasets = ClassBalancedDataset( dataset=dataset_cfg, oversample_thr=1e-3)