From 376251961da47ea8254ab808ae5c51e1430f18dc Mon Sep 17 00:00:00 2001 From: BigDong <yudongwang1226@gmail.com> Date: Tue, 1 Nov 2022 17:06:02 +0800 Subject: [PATCH] [Enhance] Add `ignore_keys` in ConcatDataset (#556) * [Fix] Fix ConcatDataset error in VOCDataset * minor fix * minor fix * minor fix * add UT * minor fix * minor fix * minor fix * Update mmengine/dataset/dataset_wrapper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * minor fix * Update mmengine/dataset/dataset_wrapper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> --- mmengine/dataset/dataset_wrapper.py | 34 ++++++++++++++++++++---- tests/test_dataset/test_base_dataset.py | 35 ++++++++++++++++++++----- 2 files changed, 58 insertions(+), 11 deletions(-) diff --git a/mmengine/dataset/dataset_wrapper.py b/mmengine/dataset/dataset_wrapper.py index b91756f6..db7be63b 100644 --- a/mmengine/dataset/dataset_wrapper.py +++ b/mmengine/dataset/dataset_wrapper.py @@ -30,11 +30,15 @@ class ConcatDataset(_ConcatDataset): which will be concatenated. lazy_init (bool, optional): Whether to load annotation during instantiation. Defaults to False. + ignore_keys (List[str] or str): Ignore the keys that can be + unequal in `dataset.metainfo`. Defaults to None. + `New in version 0.3.0.` """ def __init__(self, datasets: Sequence[Union[BaseDataset, dict]], - lazy_init: bool = False): + lazy_init: bool = False, + ignore_keys: Union[str, List[str], None] = None): self.datasets: List[BaseDataset] = [] for i, dataset in enumerate(datasets): if isinstance(dataset, dict): @@ -45,13 +49,33 @@ class ConcatDataset(_ConcatDataset): raise TypeError( 'elements in datasets sequence should be config or ' f'`BaseDataset` instance, but got {type(dataset)}') + if ignore_keys is None: + self.ignore_keys = [] + elif isinstance(ignore_keys, str): + self.ignore_keys = [ignore_keys] + elif isinstance(ignore_keys, list): + self.ignore_keys = ignore_keys + else: + raise TypeError('ignore_keys should be a list or str, ' + f'but got {type(ignore_keys)}') + + meta_keys: set = set() + for dataset in self.datasets: + meta_keys |= dataset.metainfo.keys() # Only use metainfo of first dataset. self._metainfo = self.datasets[0].metainfo for i, dataset in enumerate(self.datasets, 1): - if self._metainfo != dataset.metainfo: - raise ValueError( - f'The meta information of the {i}-th dataset does not ' - 'match meta information of the first dataset') + for key in meta_keys: + if key in self.ignore_keys: + continue + if key not in dataset.metainfo: + raise ValueError( + f'{key} does not in the meta information of ' + f'the {i}-th dataset') + if self._metainfo[key] != dataset.metainfo[key]: + raise ValueError( + f'The meta information of the {i}-th dataset does not ' + 'match meta information of the first dataset') self._fully_initialized = False if not lazy_init: diff --git a/tests/test_dataset/test_base_dataset.py b/tests/test_dataset/test_base_dataset.py index e540f527..f9a06491 100644 --- a/tests/test_dataset/test_base_dataset.py +++ b/tests/test_dataset/test_base_dataset.py @@ -640,6 +640,10 @@ class TestConcatDataset: with pytest.raises(TypeError): ConcatDataset(datasets=[0]) + with pytest.raises(TypeError): + ConcatDataset( + datasets=[self.dataset_a, dataset_cfg_b], ignore_keys=1) + def test_full_init(self): # test init with lazy_init=True self.cat_datasets.full_init() @@ -654,14 +658,33 @@ class TestConcatDataset: with pytest.raises(NotImplementedError): self.cat_datasets.get_subset(1) - # Different meta information will raise error. + + dataset_b = BaseDataset( + data_root=osp.join(osp.dirname(__file__), '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/dummy_annotation.json', + metainfo=dict(classes=('cat'))) + # Regardless of order, different meta information without + # `ignore_keys` will raise error. with pytest.raises(ValueError): - dataset_b = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json', - metainfo=dict(classes=('cat'))) ConcatDataset(datasets=[self.dataset_a, dataset_b]) + with pytest.raises(ValueError): + ConcatDataset(datasets=[dataset_b, self.dataset_a]) + # `ignore_keys` does not contain different meta information keys will + # raise error. + with pytest.raises(ValueError): + ConcatDataset( + datasets=[self.dataset_a, dataset_b], ignore_keys=['a']) + # Different meta information with `ignore_keys` will not raise error. + cat_datasets = ConcatDataset( + datasets=[self.dataset_a, dataset_b], ignore_keys='classes') + cat_datasets.full_init() + assert len(cat_datasets) == 6 + cat_datasets.full_init() + cat_datasets._fully_initialized = False + cat_datasets[1] + assert len(cat_datasets.metainfo) == 3 + assert len(cat_datasets) == 6 def test_metainfo(self): assert self.cat_datasets.metainfo == self.dataset_a.metainfo -- GitLab