Skip to content
Snippets Groups Projects
Unverified commit 37625196, authored by BigDong and committed by GitHub
Browse files

[Enhance] Add `ignore_keys` in ConcatDataset (#556)


* [Fix] Fix ConcatDataset error in VOCDataset

* minor fix

* minor fix

* minor fix

* add UT

* minor fix

* minor fix

* minor fix

* Update mmengine/dataset/dataset_wrapper.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* minor fix

* Update mmengine/dataset/dataset_wrapper.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent da9d61f6
No related branches found
No related tags found
No related merge requests found
...@@ -30,11 +30,15 @@ class ConcatDataset(_ConcatDataset): ...@@ -30,11 +30,15 @@ class ConcatDataset(_ConcatDataset):
which will be concatenated. which will be concatenated.
lazy_init (bool, optional): Whether to load annotation during lazy_init (bool, optional): Whether to load annotation during
instantiation. Defaults to False. instantiation. Defaults to False.
ignore_keys (List[str] or str): Ignore the keys that can be
unequal in `dataset.metainfo`. Defaults to None.
`New in version 0.3.0.`
""" """
def __init__(self, def __init__(self,
datasets: Sequence[Union[BaseDataset, dict]], datasets: Sequence[Union[BaseDataset, dict]],
lazy_init: bool = False): lazy_init: bool = False,
ignore_keys: Union[str, List[str], None] = None):
self.datasets: List[BaseDataset] = [] self.datasets: List[BaseDataset] = []
for i, dataset in enumerate(datasets): for i, dataset in enumerate(datasets):
if isinstance(dataset, dict): if isinstance(dataset, dict):
...@@ -45,13 +49,33 @@ class ConcatDataset(_ConcatDataset): ...@@ -45,13 +49,33 @@ class ConcatDataset(_ConcatDataset):
raise TypeError( raise TypeError(
'elements in datasets sequence should be config or ' 'elements in datasets sequence should be config or '
f'`BaseDataset` instance, but got {type(dataset)}') f'`BaseDataset` instance, but got {type(dataset)}')
if ignore_keys is None:
self.ignore_keys = []
elif isinstance(ignore_keys, str):
self.ignore_keys = [ignore_keys]
elif isinstance(ignore_keys, list):
self.ignore_keys = ignore_keys
else:
raise TypeError('ignore_keys should be a list or str, '
f'but got {type(ignore_keys)}')
meta_keys: set = set()
for dataset in self.datasets:
meta_keys |= dataset.metainfo.keys()
# Only use metainfo of first dataset. # Only use metainfo of first dataset.
self._metainfo = self.datasets[0].metainfo self._metainfo = self.datasets[0].metainfo
for i, dataset in enumerate(self.datasets, 1): for i, dataset in enumerate(self.datasets, 1):
if self._metainfo != dataset.metainfo: for key in meta_keys:
raise ValueError( if key in self.ignore_keys:
f'The meta information of the {i}-th dataset does not ' continue
'match meta information of the first dataset') if key not in dataset.metainfo:
raise ValueError(
f'{key} does not in the meta information of '
f'the {i}-th dataset')
if self._metainfo[key] != dataset.metainfo[key]:
raise ValueError(
f'The meta information of the {i}-th dataset does not '
'match meta information of the first dataset')
self._fully_initialized = False self._fully_initialized = False
if not lazy_init: if not lazy_init:
......
...@@ -640,6 +640,10 @@ class TestConcatDataset: ...@@ -640,6 +640,10 @@ class TestConcatDataset:
with pytest.raises(TypeError): with pytest.raises(TypeError):
ConcatDataset(datasets=[0]) ConcatDataset(datasets=[0])
with pytest.raises(TypeError):
ConcatDataset(
datasets=[self.dataset_a, dataset_cfg_b], ignore_keys=1)
def test_full_init(self): def test_full_init(self):
# test init with lazy_init=True # test init with lazy_init=True
self.cat_datasets.full_init() self.cat_datasets.full_init()
...@@ -654,14 +658,33 @@ class TestConcatDataset: ...@@ -654,14 +658,33 @@ class TestConcatDataset:
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
self.cat_datasets.get_subset(1) self.cat_datasets.get_subset(1)
# Different meta information will raise error.
dataset_b = BaseDataset(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img_path='imgs'),
ann_file='annotations/dummy_annotation.json',
metainfo=dict(classes=('cat')))
# Regardless of order, different meta information without
# `ignore_keys` will raise error.
with pytest.raises(ValueError): with pytest.raises(ValueError):
dataset_b = BaseDataset(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img_path='imgs'),
ann_file='annotations/dummy_annotation.json',
metainfo=dict(classes=('cat')))
ConcatDataset(datasets=[self.dataset_a, dataset_b]) ConcatDataset(datasets=[self.dataset_a, dataset_b])
with pytest.raises(ValueError):
ConcatDataset(datasets=[dataset_b, self.dataset_a])
# `ignore_keys` does not contain different meta information keys will
# raise error.
with pytest.raises(ValueError):
ConcatDataset(
datasets=[self.dataset_a, dataset_b], ignore_keys=['a'])
# Different meta information with `ignore_keys` will not raise error.
cat_datasets = ConcatDataset(
datasets=[self.dataset_a, dataset_b], ignore_keys='classes')
cat_datasets.full_init()
assert len(cat_datasets) == 6
cat_datasets.full_init()
cat_datasets._fully_initialized = False
cat_datasets[1]
assert len(cat_datasets.metainfo) == 3
assert len(cat_datasets) == 6
def test_metainfo(self): def test_metainfo(self):
assert self.cat_datasets.metainfo == self.dataset_a.metainfo assert self.cat_datasets.metainfo == self.dataset_a.metainfo
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment