From db2b45b7ac6ca1abb64ff364329a69a1f06974ec Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Wed, 27 Apr 2022 19:48:28 +0800
Subject: [PATCH] [Fix] Fix the probable failure of the basedataset unit test
 (#210)

---
 tests/test_data/test_base_dataset.py | 106 +++++++++++++--------------
 1 file changed, 51 insertions(+), 55 deletions(-)

diff --git a/tests/test_data/test_base_dataset.py b/tests/test_data/test_base_dataset.py
index f82be769..d53ff91c 100644
--- a/tests/test_data/test_base_dataset.py
+++ b/tests/test_data/test_base_dataset.py
@@ -33,29 +33,28 @@ class CustomDataset(BaseDataset):
 
 
 class TestBaseDataset:
-    dataset_type = BaseDataset
-    data_info = dict(
-        filename='test_img.jpg', height=604, width=640, sample_idx=0)
-    imgs = torch.rand((2, 3, 32, 32))
-    pipeline = MagicMock(return_value=dict(imgs=imgs))
-    METAINFO: dict = dict()
-    parse_data_info = MagicMock(return_value=data_info)
-
-    def _init_dataset(self):
-        self.dataset_type.METAINFO = self.METAINFO
-        self.dataset_type.parse_data_info = self.parse_data_info
+
+    def setup(self):
+        self.data_info = dict(
+            filename='test_img.jpg', height=604, width=640, sample_idx=0)
+        self.imgs = torch.rand((2, 3, 32, 32))
+        self.ori_meta = BaseDataset.METAINFO
+        BaseDataset.parse_data_info = MagicMock(return_value=self.data_info)
+        self.pipeline = MagicMock(return_value=dict(imgs=self.imgs))
+
+    def teardown(self):
+        BaseDataset.METAINFO = self.ori_meta
 
     def test_init(self):
-        self._init_dataset()
         # test the instantiation of self.base_dataset
-        dataset = self.dataset_type(
+        dataset = BaseDataset(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img='imgs'),
             ann_file='annotations/dummy_annotation.json')
         assert dataset._fully_initialized
         assert hasattr(dataset, 'data_list')
         assert hasattr(dataset, 'data_address')
-        dataset = self.dataset_type(
+        dataset = BaseDataset(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img=None),
             ann_file='annotations/dummy_annotation.json')
@@ -65,7 +64,7 @@ class TestBaseDataset:
 
         # test the instantiation of self.base_dataset with
         # `serialize_data=False`
-        dataset = self.dataset_type(
+        dataset = BaseDataset(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img='imgs'),
             ann_file='annotations/dummy_annotation.json',
@@ -77,7 +76,7 @@ class TestBaseDataset:
         assert dataset.get_data_info(0) == self.data_info
 
         # test the instantiation of self.base_dataset with lazy init
-        dataset = self.dataset_type(
+        dataset = BaseDataset(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img='imgs'),
             ann_file='annotations/dummy_annotation.json',
@@ -88,40 +87,40 @@ class TestBaseDataset:
         # test the instantiation of self.base_dataset if ann_file is not
         # existed.
         with pytest.raises(FileNotFoundError):
-            self.dataset_type(
+            BaseDataset(
                 data_root=osp.join(osp.dirname(__file__), '../data/'),
                 data_prefix=dict(img='imgs'),
                 ann_file='annotations/not_existed_annotation.json')
         # Use the default value of ann_file, i.e., ''
         with pytest.raises(FileNotFoundError):
-            self.dataset_type(
+            BaseDataset(
                 data_root=osp.join(osp.dirname(__file__), '../data/'),
                 data_prefix=dict(img='imgs'))
 
         # test the instantiation of self.base_dataset when the ann_file is
         # wrong
         with pytest.raises(ValueError):
-            self.dataset_type(
+            BaseDataset(
                 data_root=osp.join(osp.dirname(__file__), '../data/'),
                 data_prefix=dict(img='imgs'),
                 ann_file='annotations/annotation_wrong_keys.json')
         with pytest.raises(TypeError):
-            self.dataset_type(
+            BaseDataset(
                 data_root=osp.join(osp.dirname(__file__), '../data/'),
                 data_prefix=dict(img='imgs'),
                 ann_file='annotations/annotation_wrong_format.json')
         with pytest.raises(TypeError):
-            self.dataset_type(
+            BaseDataset(
                 data_root=osp.join(osp.dirname(__file__), '../data/'),
                 data_prefix=dict(img=['img']),
                 ann_file='annotations/annotation_wrong_format.json')
 
         # test the instantiation of self.base_dataset when `parse_data_info`
         # return `list[dict]`
-        self.dataset_type.parse_data_info = MagicMock(
+        BaseDataset.parse_data_info = MagicMock(
             return_value=[self.data_info,
                           self.data_info.copy()])
-        dataset = self.dataset_type(
+        dataset = BaseDataset(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img='imgs'),
             ann_file='annotations/dummy_annotation.json')
@@ -136,24 +135,23 @@ class TestBaseDataset:
         # test the instantiation of self.base_dataset when `parse_data_info`
         # return unsupported data.
         with pytest.raises(TypeError):
-            self.dataset_type.parse_data_info = MagicMock(return_value='xxx')
-            dataset = self.dataset_type(
+            BaseDataset.parse_data_info = MagicMock(return_value='xxx')
+            dataset = BaseDataset(
                 data_root=osp.join(osp.dirname(__file__), '../data/'),
                 data_prefix=dict(img='imgs'),
                 ann_file='annotations/dummy_annotation.json')
         with pytest.raises(TypeError):
-            self.dataset_type.parse_data_info = MagicMock(
+            BaseDataset.parse_data_info = MagicMock(
                 return_value=[self.data_info, 'xxx'])
-            self.dataset_type(
+            BaseDataset(
                 data_root=osp.join(osp.dirname(__file__), '../data/'),
                 data_prefix=dict(img='imgs'),
                 ann_file='annotations/dummy_annotation.json')
 
     def test_meta(self):
-        self._init_dataset()
         # test dataset.metainfo with setting the metainfo from annotation file
         # as the metainfo of self.base_dataset.
-        dataset = self.dataset_type(
+        dataset = BaseDataset(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img='imgs'),
             ann_file='annotations/dummy_annotation.json')
@@ -163,10 +161,10 @@ class TestBaseDataset:
 
         # test dataset.metainfo with setting METAINFO in self.base_dataset
         dataset_type = 'new_dataset'
-        self.dataset_type.METAINFO = dict(
+        BaseDataset.METAINFO = dict(
             dataset_type=dataset_type, classes=('dog', 'cat'))
 
-        dataset = self.dataset_type(
+        dataset = BaseDataset(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img='imgs'),
             ann_file='annotations/dummy_annotation.json')
@@ -178,12 +176,12 @@ class TestBaseDataset:
 
         # test dataset.metainfo with passing metainfo into self.base_dataset
         metainfo = dict(classes=('dog', ), task_name='new_task')
-        dataset = self.dataset_type(
+        dataset = BaseDataset(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img='imgs'),
             ann_file='annotations/dummy_annotation.json',
             metainfo=metainfo)
-        assert self.dataset_type.METAINFO == dict(
+        assert BaseDataset.METAINFO == dict(
             dataset_type=dataset_type, classes=('dog', 'cat'))
         assert dataset.metainfo == dict(
             dataset_type=dataset_type,
@@ -192,8 +190,8 @@ class TestBaseDataset:
             empty_list=[])
         # reset `base_dataset.METAINFO`, the `dataset.metainfo` should not
         # change
-        self.dataset_type.METAINFO['classes'] = ('dog', 'cat', 'fish')
-        assert self.dataset_type.METAINFO == dict(
+        BaseDataset.METAINFO['classes'] = ('dog', 'cat', 'fish')
+        assert BaseDataset.METAINFO == dict(
             dataset_type=dataset_type, classes=('dog', 'cat', 'fish'))
         assert dataset.metainfo == dict(
             dataset_type=dataset_type,
@@ -206,7 +204,7 @@ class TestBaseDataset:
         metainfo = dict(
             classes=osp.join(
                 osp.dirname(__file__), '../data/meta/classes.txt'))
-        dataset = self.dataset_type(
+        dataset = BaseDataset(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img='imgs'),
             ann_file='annotations/dummy_annotation.json',
@@ -221,7 +219,7 @@ class TestBaseDataset:
         # self.base_dataset
         with pytest.raises(TypeError):
             metainfo = 'dog'
-            dataset = self.dataset_type(
+            dataset = BaseDataset(
                 data_root=osp.join(osp.dirname(__file__), '../data/'),
                 data_prefix=dict(img='imgs'),
                 ann_file='annotations/dummy_annotation.json',
@@ -230,7 +228,7 @@ class TestBaseDataset:
         # test dataset.metainfo with passing metainfo into self.base_dataset
         # and lazy_init is True
         metainfo = dict(classes=('dog', ))
-        dataset = self.dataset_type(
+        dataset = BaseDataset(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img='imgs'),
             ann_file='annotations/dummy_annotation.json',
@@ -243,26 +241,26 @@ class TestBaseDataset:
         # test whether self.base_dataset.METAINFO is changed when a customize
         # dataset inherit self.base_dataset
         # test reset METAINFO in ToyDataset.
-        class ToyDataset(self.dataset_type):
+        class ToyDataset(BaseDataset):
             METAINFO = dict(xxx='xxx')
 
         assert ToyDataset.METAINFO == dict(xxx='xxx')
-        assert self.dataset_type.METAINFO == dict(
+        assert BaseDataset.METAINFO == dict(
             dataset_type=dataset_type, classes=('dog', 'cat', 'fish'))
 
         # test update METAINFO in ToyDataset.
-        class ToyDataset(self.dataset_type):
-            METAINFO = copy.deepcopy(self.dataset_type.METAINFO)
+        class ToyDataset(BaseDataset):
+            METAINFO = copy.deepcopy(BaseDataset.METAINFO)
             METAINFO['classes'] = ('bird', )
 
         assert ToyDataset.METAINFO == dict(
             dataset_type=dataset_type, classes=('bird', ))
-        assert self.dataset_type.METAINFO == dict(
+        assert BaseDataset.METAINFO == dict(
             dataset_type=dataset_type, classes=('dog', 'cat', 'fish'))
 
     @pytest.mark.parametrize('lazy_init', [True, False])
     def test_length(self, lazy_init):
-        dataset = self.dataset_type(
+        dataset = BaseDataset(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img='imgs'),
             ann_file='annotations/dummy_annotation.json',
@@ -311,7 +309,7 @@ class TestBaseDataset:
 
     @pytest.mark.parametrize('lazy_init', [True, False])
     def test_getitem(self, lazy_init):
-        dataset = self.dataset_type(
+        dataset = BaseDataset(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img='imgs'),
             ann_file='annotations/dummy_annotation.json',
@@ -353,8 +351,7 @@ class TestBaseDataset:
 
     @pytest.mark.parametrize('lazy_init', [True, False])
     def test_get_data_info(self, lazy_init):
-        self._init_dataset()
-        dataset = self.dataset_type(
+        dataset = BaseDataset(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img='imgs'),
             ann_file='annotations/dummy_annotation.json',
@@ -386,8 +383,7 @@ class TestBaseDataset:
             class_without_full_init.foo()
 
     def test_full_init(self):
-        self._init_dataset()
-        dataset = self.dataset_type(
+        dataset = BaseDataset(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img='imgs'),
             ann_file='annotations/dummy_annotation.json',
@@ -404,7 +400,7 @@ class TestBaseDataset:
         assert dataset[0] == dict(imgs=self.imgs)
         assert dataset.get_data_info(0) == self.data_info
 
-        dataset = self.dataset_type(
+        dataset = BaseDataset(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img='imgs'),
             ann_file='annotations/dummy_annotation.json',
@@ -418,11 +414,11 @@ class TestBaseDataset:
         assert dataset.get_data_info(0) == self.data_info
 
         # test the instantiation of self.base_dataset when passing indices
-        dataset = self.dataset_type(
+        dataset = BaseDataset(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img=None),
             ann_file='annotations/dummy_annotation.json')
-        dataset_sliced = self.dataset_type(
+        dataset_sliced = BaseDataset(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img=None),
             ann_file='annotations/dummy_annotation.json',
@@ -436,7 +432,7 @@ class TestBaseDataset:
     def test_get_subset_(self, lazy_init, serialize_data):
         # Test positive int indices.
         indices = 2
-        dataset = self.dataset_type(
+        dataset = BaseDataset(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img=None),
             ann_file='annotations/dummy_annotation.json',
@@ -514,7 +510,7 @@ class TestBaseDataset:
     def test_get_subset(self, lazy_init, serialize_data):
         # Test positive indices.
         indices = 2
-        dataset = self.dataset_type(
+        dataset = BaseDataset(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img=None),
             ann_file='annotations/dummy_annotation.json',
@@ -560,7 +556,7 @@ class TestBaseDataset:
 
     def test_rand_another(self):
         # test the instantiation of self.base_dataset when passing num_samples
-        dataset = self.dataset_type(
+        dataset = BaseDataset(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img=None),
             ann_file='annotations/dummy_annotation.json',
-- 
GitLab