diff --git a/mmengine/dataset/dataset_wrapper.py b/mmengine/dataset/dataset_wrapper.py index b91756f651115cabaed544de851c5df27b0ff65a..db7be63be512f96ca41936633b423b834e9cb6b5 100644 --- a/mmengine/dataset/dataset_wrapper.py +++ b/mmengine/dataset/dataset_wrapper.py @@ -30,11 +30,15 @@ class ConcatDataset(_ConcatDataset): which will be concatenated. lazy_init (bool, optional): Whether to load annotation during instantiation. Defaults to False. + ignore_keys (List[str] or str): Ignore the keys that can be + unequal in `dataset.metainfo`. Defaults to None. + `New in version 0.3.0.` """ def __init__(self, datasets: Sequence[Union[BaseDataset, dict]], - lazy_init: bool = False): + lazy_init: bool = False, + ignore_keys: Union[str, List[str], None] = None): self.datasets: List[BaseDataset] = [] for i, dataset in enumerate(datasets): if isinstance(dataset, dict): @@ -45,13 +49,33 @@ class ConcatDataset(_ConcatDataset): raise TypeError( 'elements in datasets sequence should be config or ' f'`BaseDataset` instance, but got {type(dataset)}') + if ignore_keys is None: + self.ignore_keys = [] + elif isinstance(ignore_keys, str): + self.ignore_keys = [ignore_keys] + elif isinstance(ignore_keys, list): + self.ignore_keys = ignore_keys + else: + raise TypeError('ignore_keys should be a list or str, ' + f'but got {type(ignore_keys)}') + + meta_keys: set = set() + for dataset in self.datasets: + meta_keys |= dataset.metainfo.keys() # Only use metainfo of first dataset. self._metainfo = self.datasets[0].metainfo for i, dataset in enumerate(self.datasets, 1): - if self._metainfo != dataset.metainfo: - raise ValueError( - f'The meta information of the {i}-th dataset does not ' - 'match meta information of the first dataset') + for key in meta_keys: + if key in self.ignore_keys: + continue + if key not in dataset.metainfo: + raise ValueError( + f'{key} does not in the meta information of ' + f'the {i}-th dataset') + if self._metainfo[key] != dataset.metainfo[key]: + raise ValueError( + f'The meta information of the {i}-th dataset does not ' + 'match meta information of the first dataset') self._fully_initialized = False if not lazy_init: diff --git a/tests/test_dataset/test_base_dataset.py b/tests/test_dataset/test_base_dataset.py index e540f5276f833366ae0c54c9104eecc3321b3a56..f9a064914fa5d72c7f9bd524bc244939a6ad92f5 100644 --- a/tests/test_dataset/test_base_dataset.py +++ b/tests/test_dataset/test_base_dataset.py @@ -640,6 +640,10 @@ class TestConcatDataset: with pytest.raises(TypeError): ConcatDataset(datasets=[0]) + with pytest.raises(TypeError): + ConcatDataset( + datasets=[self.dataset_a, dataset_cfg_b], ignore_keys=1) + def test_full_init(self): # test init with lazy_init=True self.cat_datasets.full_init() @@ -654,14 +658,33 @@ class TestConcatDataset: with pytest.raises(NotImplementedError): self.cat_datasets.get_subset(1) - # Different meta information will raise error. + + dataset_b = BaseDataset( + data_root=osp.join(osp.dirname(__file__), '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/dummy_annotation.json', + metainfo=dict(classes=('cat'))) + # Regardless of order, different meta information without + # `ignore_keys` will raise error. with pytest.raises(ValueError): - dataset_b = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json', - metainfo=dict(classes=('cat'))) ConcatDataset(datasets=[self.dataset_a, dataset_b]) + with pytest.raises(ValueError): + ConcatDataset(datasets=[dataset_b, self.dataset_a]) + # `ignore_keys` does not contain different meta information keys will + # raise error. + with pytest.raises(ValueError): + ConcatDataset( + datasets=[self.dataset_a, dataset_b], ignore_keys=['a']) + # Different meta information with `ignore_keys` will not raise error. + cat_datasets = ConcatDataset( + datasets=[self.dataset_a, dataset_b], ignore_keys='classes') + cat_datasets.full_init() + assert len(cat_datasets) == 6 + cat_datasets.full_init() + cat_datasets._fully_initialized = False + cat_datasets[1] + assert len(cat_datasets.metainfo) == 3 + assert len(cat_datasets) == 6 def test_metainfo(self): assert self.cat_datasets.metainfo == self.dataset_a.metainfo