From 376251961da47ea8254ab808ae5c51e1430f18dc Mon Sep 17 00:00:00 2001
From: BigDong <yudongwang1226@gmail.com>
Date: Tue, 1 Nov 2022 17:06:02 +0800
Subject: [PATCH] [Enhance] Add `ignore_keys` in ConcatDataset (#556)

* [Fix] Fix ConcatDataset error in VOCDataset

* minor fix

* minor fix

* minor fix

* add UT

* minor fix

* minor fix

* minor fix

* Update mmengine/dataset/dataset_wrapper.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* minor fix

* Update mmengine/dataset/dataset_wrapper.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
---
 mmengine/dataset/dataset_wrapper.py     | 34 ++++++++++++++++++++----
 tests/test_dataset/test_base_dataset.py | 35 ++++++++++++++++++++-----
 2 files changed, 58 insertions(+), 11 deletions(-)

diff --git a/mmengine/dataset/dataset_wrapper.py b/mmengine/dataset/dataset_wrapper.py
index b91756f6..db7be63b 100644
--- a/mmengine/dataset/dataset_wrapper.py
+++ b/mmengine/dataset/dataset_wrapper.py
@@ -30,11 +30,15 @@ class ConcatDataset(_ConcatDataset):
             which will be concatenated.
         lazy_init (bool, optional): Whether to load annotation during
             instantiation. Defaults to False.
+        ignore_keys (List[str] or str, optional): Keys in
+            `dataset.metainfo` that are allowed to differ between the
+            datasets. Defaults to None. `New in version 0.3.0.`
     """
 
     def __init__(self,
                  datasets: Sequence[Union[BaseDataset, dict]],
-                 lazy_init: bool = False):
+                 lazy_init: bool = False,
+                 ignore_keys: Union[str, List[str], None] = None):
         self.datasets: List[BaseDataset] = []
         for i, dataset in enumerate(datasets):
             if isinstance(dataset, dict):
@@ -45,13 +49,33 @@ class ConcatDataset(_ConcatDataset):
                 raise TypeError(
                     'elements in datasets sequence should be config or '
                     f'`BaseDataset` instance, but got {type(dataset)}')
+        if ignore_keys is None:
+            self.ignore_keys = []
+        elif isinstance(ignore_keys, str):
+            self.ignore_keys = [ignore_keys]
+        elif isinstance(ignore_keys, list):
+            self.ignore_keys = ignore_keys
+        else:
+            raise TypeError('ignore_keys should be a list or str, '
+                            f'but got {type(ignore_keys)}')
+
+        meta_keys: set = set()
+        for dataset in self.datasets:
+            meta_keys |= dataset.metainfo.keys()
         # Only use metainfo of first dataset.
         self._metainfo = self.datasets[0].metainfo
         for i, dataset in enumerate(self.datasets, 1):
-            if self._metainfo != dataset.metainfo:
-                raise ValueError(
-                    f'The meta information of the {i}-th dataset does not '
-                    'match meta information of the first dataset')
+            for key in meta_keys:
+                if key in self.ignore_keys:
+                    continue
+                if key not in dataset.metainfo:
+                    raise ValueError(
+                        f'{key} does not in the meta information of '
+                        f'the {i}-th dataset')
+                if self._metainfo[key] != dataset.metainfo[key]:
+                    raise ValueError(
+                        f'The meta information of the {i}-th dataset does not '
+                        'match meta information of the first dataset')
 
         self._fully_initialized = False
         if not lazy_init:
diff --git a/tests/test_dataset/test_base_dataset.py b/tests/test_dataset/test_base_dataset.py
index e540f527..f9a06491 100644
--- a/tests/test_dataset/test_base_dataset.py
+++ b/tests/test_dataset/test_base_dataset.py
@@ -640,6 +640,10 @@ class TestConcatDataset:
         with pytest.raises(TypeError):
             ConcatDataset(datasets=[0])
 
+        with pytest.raises(TypeError):
+            ConcatDataset(
+                datasets=[self.dataset_a, dataset_cfg_b], ignore_keys=1)
+
     def test_full_init(self):
         # test init with lazy_init=True
         self.cat_datasets.full_init()
@@ -654,14 +658,33 @@ class TestConcatDataset:
 
         with pytest.raises(NotImplementedError):
             self.cat_datasets.get_subset(1)
-        # Different meta information will raise error.
+
+        dataset_b = BaseDataset(
+            data_root=osp.join(osp.dirname(__file__), '../data/'),
+            data_prefix=dict(img_path='imgs'),
+            ann_file='annotations/dummy_annotation.json',
+            metainfo=dict(classes=('cat')))
+        # Regardless of order, different meta information without
+        # `ignore_keys` will raise error.
         with pytest.raises(ValueError):
-            dataset_b = BaseDataset(
-                data_root=osp.join(osp.dirname(__file__), '../data/'),
-                data_prefix=dict(img_path='imgs'),
-                ann_file='annotations/dummy_annotation.json',
-                metainfo=dict(classes=('cat')))
             ConcatDataset(datasets=[self.dataset_a, dataset_b])
+        with pytest.raises(ValueError):
+            ConcatDataset(datasets=[dataset_b, self.dataset_a])
+        # If `ignore_keys` does not contain the mismatched meta information
+        # keys, an error will still be raised.
+        with pytest.raises(ValueError):
+            ConcatDataset(
+                datasets=[self.dataset_a, dataset_b], ignore_keys=['a'])
+        # Different meta information with `ignore_keys` will not raise error.
+        cat_datasets = ConcatDataset(
+            datasets=[self.dataset_a, dataset_b], ignore_keys='classes')
+        cat_datasets.full_init()
+        assert len(cat_datasets) == 6
+        cat_datasets.full_init()
+        cat_datasets._fully_initialized = False
+        cat_datasets[1]
+        assert len(cat_datasets.metainfo) == 3
+        assert len(cat_datasets) == 6
 
     def test_metainfo(self):
         assert self.cat_datasets.metainfo == self.dataset_a.metainfo
-- 
GitLab