From 4dcbd269aa85e7af6d1d6857de731e9055f47102 Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Tue, 26 Apr 2022 16:52:13 +0800
Subject: [PATCH] [Enhance] register dataset wrapper (#185)

* register dataset wrapper

* fix as comment
---
 mmengine/dataset/dataset_wrapper.py  | 55 ++++++++++++----
 tests/test_data/test_base_dataset.py | 95 ++++++++++++++++++++--------
 2 files changed, 110 insertions(+), 40 deletions(-)

diff --git a/mmengine/dataset/dataset_wrapper.py b/mmengine/dataset/dataset_wrapper.py
index 454ec96d..2cf995e7 100644
--- a/mmengine/dataset/dataset_wrapper.py
+++ b/mmengine/dataset/dataset_wrapper.py
@@ -8,9 +8,11 @@ from typing import List, Sequence, Tuple, Union
 
 from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
 
+from mmengine.registry import DATASETS
 from .base_dataset import BaseDataset, force_full_init
 
 
+@DATASETS.register_module()
 class ConcatDataset(_ConcatDataset):
     """A wrapper of concatenated dataset.
 
@@ -24,19 +26,28 @@ class ConcatDataset(_ConcatDataset):
         arguments for wrapped dataset which inherit from ``BaseDataset``.
 
     Args:
-        datasets (Sequence[BaseDataset]): A list of datasets which will be
-            concatenated.
+        datasets (Sequence[BaseDataset] or Sequence[dict]): A list of datasets
+            which will be concatenated.
         lazy_init (bool, optional): Whether to load annotation during
             instantiation. Defaults to False.
     """
 
     def __init__(self,
-                 datasets: Sequence[BaseDataset],
+                 datasets: Sequence[Union[BaseDataset, dict]],
                  lazy_init: bool = False):
+        self.datasets: List[BaseDataset] = []
+        for i, dataset in enumerate(datasets):
+            if isinstance(dataset, dict):
+                self.datasets.append(DATASETS.build(dataset))
+            elif isinstance(dataset, BaseDataset):
+                self.datasets.append(dataset)
+            else:
+                raise TypeError(
+                    'elements in datasets sequence should be config or '
+                    f'`BaseDataset` instance, but got {type(dataset)}')
         # Only use metainfo of first dataset.
-        self._metainfo = datasets[0].metainfo
-        self.datasets = datasets  # type: ignore
-        for i, dataset in enumerate(datasets, 1):
+        self._metainfo = self.datasets[0].metainfo
+        for i, dataset in enumerate(self.datasets, 1):
             if self._metainfo != dataset.metainfo:
                 raise ValueError(
                     f'The meta information of the {i}-th dataset does not '
@@ -140,6 +151,7 @@ class ConcatDataset(_ConcatDataset):
             'dataset first and then use `ConcatDataset`.')
 
 
+@DATASETS.register_module()
 class RepeatDataset:
     """A wrapper of repeated dataset.
 
@@ -156,19 +168,27 @@ class RepeatDataset:
         arguments for wrapped dataset which inherit from ``BaseDataset``.
 
     Args:
-        dataset (BaseDataset): The dataset to be repeated.
+        dataset (BaseDataset or dict): The dataset to be repeated.
         times (int): Repeat times.
         lazy_init (bool): Whether to load annotation during
             instantiation. Defaults to False.
     """
 
     def __init__(self,
-                 dataset: BaseDataset,
+                 dataset: Union[BaseDataset, dict],
                  times: int,
                  lazy_init: bool = False):
-        self.dataset = dataset
+        self.dataset: BaseDataset
+        if isinstance(dataset, dict):
+            self.dataset = DATASETS.build(dataset)
+        elif isinstance(dataset, BaseDataset):
+            self.dataset = dataset
+        else:
+            raise TypeError(
+                'dataset should be a config dict or a `BaseDataset` '
+                f'instance, but got {type(dataset)}')
         self.times = times
-        self._metainfo = dataset.metainfo
+        self._metainfo = self.dataset.metainfo
 
         self._fully_initialized = False
         if not lazy_init:
@@ -283,7 +303,7 @@ class ClassBalancedDataset:
         ``BaseDataset``.
 
     Args:
-        dataset (BaseDataset): The dataset to be repeated.
+        dataset (BaseDataset or dict): The dataset to be repeated.
         oversample_thr (float): frequency threshold below which data is
             repeated. For categories with ``f_c >= oversample_thr``, there is
             no oversampling. For categories with ``f_c < oversample_thr``, the
@@ -294,12 +314,19 @@ class ClassBalancedDataset:
     """
 
     def __init__(self,
-                 dataset: BaseDataset,
+                 dataset: Union[BaseDataset, dict],
                  oversample_thr: float,
                  lazy_init: bool = False):
-        self.dataset = dataset
+        if isinstance(dataset, dict):
+            self.dataset = DATASETS.build(dataset)
+        elif isinstance(dataset, BaseDataset):
+            self.dataset = dataset
+        else:
+            raise TypeError(
+                'dataset should be a config dict or a `BaseDataset` '
+                f'instance, but got {type(dataset)}')
         self.oversample_thr = oversample_thr
-        self._metainfo = dataset.metainfo
+        self._metainfo = self.dataset.metainfo
 
         self._fully_initialized = False
         if not lazy_init:
diff --git a/tests/test_data/test_base_dataset.py b/tests/test_data/test_base_dataset.py
index 2488144f..f82be769 100644
--- a/tests/test_data/test_base_dataset.py
+++ b/tests/test_data/test_base_dataset.py
@@ -8,7 +8,7 @@ import torch
 
 from mmengine.dataset import (BaseDataset, ClassBalancedDataset, Compose,
                               ConcatDataset, RepeatDataset, force_full_init)
-from mmengine.registry import TRANSFORMS
+from mmengine.registry import DATASETS, TRANSFORMS
 
 
 def function_pipeline(data_info):
@@ -27,6 +27,11 @@ class NotCallableTransform:
     pass
 
 
+@DATASETS.register_module()
+class CustomDataset(BaseDataset):
+    pass
+
+
 class TestBaseDataset:
     dataset_type = BaseDataset
     data_info = dict(
@@ -566,7 +571,7 @@ class TestBaseDataset:
 
 class TestConcatDataset:
 
-    def _init_dataset(self):
+    def setup(self):
         dataset = BaseDataset
 
         # create dataset_a
@@ -593,8 +598,25 @@ class TestConcatDataset:
         self.cat_datasets = ConcatDataset(
             datasets=[self.dataset_a, self.dataset_b])
 
+    def test_init(self):
+        # Test building the wrapped dataset from a config dict.
+        dataset_cfg_b = dict(
+            type=CustomDataset,
+            data_root=osp.join(osp.dirname(__file__), '../data/'),
+            data_prefix=dict(img='imgs'),
+            ann_file='annotations/dummy_annotation.json')
+        cat_datasets = ConcatDataset(datasets=[self.dataset_a, dataset_cfg_b])
+        cat_datasets.datasets[1].pipeline = self.dataset_b.pipeline
+        assert len(cat_datasets) == len(self.cat_datasets)
+        for i in range(len(cat_datasets)):
+            assert (cat_datasets.get_data_info(i) ==
+                    self.cat_datasets.get_data_info(i))
+            assert (cat_datasets[i] == self.cat_datasets[i])
+
+        with pytest.raises(TypeError):
+            ConcatDataset(datasets=[0])
+
     def test_full_init(self):
-        self._init_dataset()
         # test init with lazy_init=True
         self.cat_datasets.full_init()
         assert len(self.cat_datasets) == 6
@@ -618,16 +640,13 @@ class TestConcatDataset:
             ConcatDataset(datasets=[self.dataset_a, dataset_b])
 
     def test_metainfo(self):
-        self._init_dataset()
         assert self.cat_datasets.metainfo == self.dataset_a.metainfo
 
     def test_length(self):
-        self._init_dataset()
         assert len(self.cat_datasets) == (
             len(self.dataset_a) + len(self.dataset_b))
 
     def test_getitem(self):
-        self._init_dataset()
         assert (
             self.cat_datasets[0]['imgs'] == self.dataset_a[0]['imgs']).all()
         assert (self.cat_datasets[0]['imgs'] !=
@@ -639,7 +658,6 @@ class TestConcatDataset:
             self.dataset_a[-1]['imgs']).all()
 
     def test_get_data_info(self):
-        self._init_dataset()
         assert self.cat_datasets.get_data_info(
             0) == self.dataset_a.get_data_info(0)
         assert self.cat_datasets.get_data_info(
@@ -651,7 +669,6 @@ class TestConcatDataset:
             -1) != self.dataset_a.get_data_info(-1)
 
     def test_get_ori_dataset_idx(self):
-        self._init_dataset()
         assert self.cat_datasets._get_ori_dataset_idx(3) == (
             1, 3 - len(self.dataset_a))
         assert self.cat_datasets._get_ori_dataset_idx(-1) == (
@@ -662,7 +679,7 @@ class TestConcatDataset:
 
 class TestRepeatDataset:
 
-    def _init_dataset(self):
+    def setup(self):
         dataset = BaseDataset
         data_info = dict(filename='test_img.jpg', height=604, width=640)
         dataset.parse_data_info = MagicMock(return_value=data_info)
@@ -678,9 +695,26 @@ class TestRepeatDataset:
         self.repeat_datasets = RepeatDataset(
             dataset=self.dataset, times=self.repeat_times)
 
-    def test_full_init(self):
-        self._init_dataset()
+    def test_init(self):
+        # Test building the wrapped dataset from a config dict.
+        dataset_cfg = dict(
+            type=CustomDataset,
+            data_root=osp.join(osp.dirname(__file__), '../data/'),
+            data_prefix=dict(img='imgs'),
+            ann_file='annotations/dummy_annotation.json')
+        repeat_dataset = RepeatDataset(
+            dataset=dataset_cfg, times=self.repeat_times)
+        repeat_dataset.dataset.pipeline = self.dataset.pipeline
+        assert len(repeat_dataset) == len(self.repeat_datasets)
+        for i in range(len(repeat_dataset)):
+            assert (repeat_dataset.get_data_info(i) ==
+                    self.repeat_datasets.get_data_info(i))
+            assert (repeat_dataset[i] == self.repeat_datasets[i])
+
+        with pytest.raises(TypeError):
+            RepeatDataset(dataset=[0], times=5)
+
+    def test_full_init(self):
         self.repeat_datasets.full_init()
         assert len(
             self.repeat_datasets) == self.repeat_times * len(self.dataset)
@@ -697,22 +731,18 @@ class TestRepeatDataset:
             self.repeat_datasets.get_subset(1)
 
     def test_metainfo(self):
-        self._init_dataset()
         assert self.repeat_datasets.metainfo == self.dataset.metainfo
 
     def test_length(self):
-        self._init_dataset()
         assert len(
             self.repeat_datasets) == len(self.dataset) * self.repeat_times
 
     def test_getitem(self):
-        self._init_dataset()
         for i in range(self.repeat_times):
             assert self.repeat_datasets[len(self.dataset) *
                                         i] == self.dataset[0]
 
     def test_get_data_info(self):
-        self._init_dataset()
         for i in range(self.repeat_times):
             assert self.repeat_datasets.get_data_info(
                 len(self.dataset) * i) == self.dataset.get_data_info(0)
@@ -720,7 +750,7 @@ class TestRepeatDataset:
 
 class TestClassBalancedDataset:
 
-    def _init_dataset(self):
+    def setup(self):
         dataset = BaseDataset
         data_info = dict(filename='test_img.jpg', height=604, width=640)
         dataset.parse_data_info = MagicMock(return_value=data_info)
@@ -738,17 +768,35 @@ class TestClassBalancedDataset:
             dataset=self.dataset, oversample_thr=1e-3)
         self.cls_banlanced_datasets.repeat_indices = self.repeat_indices
 
-    def test_full_init(self):
-        self._init_dataset()
+    def test_init(self):
+        # Test building the wrapped dataset from a config dict.
+        dataset_cfg = dict(
+            type=CustomDataset,
+            data_root=osp.join(osp.dirname(__file__), '../data/'),
+            data_prefix=dict(img='imgs'),
+            ann_file='annotations/dummy_annotation.json')
+        cls_banlanced_datasets = ClassBalancedDataset(
+            dataset=dataset_cfg, oversample_thr=1e-3)
+        cls_banlanced_datasets.repeat_indices = self.repeat_indices
+        cls_banlanced_datasets.dataset.pipeline = self.dataset.pipeline
+        assert len(cls_banlanced_datasets) == len(self.cls_banlanced_datasets)
+        for i in range(len(cls_banlanced_datasets)):
+            assert (cls_banlanced_datasets.get_data_info(i) ==
+                    self.cls_banlanced_datasets.get_data_info(i))
+            assert (
+                cls_banlanced_datasets[i] == self.cls_banlanced_datasets[i])
+
+        with pytest.raises(TypeError):
+            ClassBalancedDataset(dataset=[0], oversample_thr=1e-3)
+
+    def test_full_init(self):
         self.cls_banlanced_datasets.full_init()
         self.cls_banlanced_datasets.repeat_indices = self.repeat_indices
         assert len(self.cls_banlanced_datasets) == len(self.repeat_indices)
-        self.cls_banlanced_datasets.full_init()
+        # Reinit `repeat_indices`.
         self.cls_banlanced_datasets._fully_initialized = False
-        self.cls_banlanced_datasets[1]
         self.cls_banlanced_datasets.repeat_indices = self.repeat_indices
-        assert len(self.cls_banlanced_datasets) == len(self.repeat_indices)
+        assert len(self.cls_banlanced_datasets) != len(self.repeat_indices)
 
         with pytest.raises(NotImplementedError):
             self.cls_banlanced_datasets.get_subset_(1)
@@ -757,27 +805,22 @@ class TestClassBalancedDataset:
         with pytest.raises(NotImplementedError):
             self.cls_banlanced_datasets.get_subset(1)
 
     def test_metainfo(self):
-        self._init_dataset()
         assert self.cls_banlanced_datasets.metainfo == self.dataset.metainfo
 
     def test_length(self):
-        self._init_dataset()
         assert len(self.cls_banlanced_datasets) == len(self.repeat_indices)
 
     def test_getitem(self):
-        self._init_dataset()
         for i in range(len(self.repeat_indices)):
             assert self.cls_banlanced_datasets[i] == self.dataset[
                 self.repeat_indices[i]]
 
     def test_get_data_info(self):
-        self._init_dataset()
         for i in range(len(self.repeat_indices)):
             assert self.cls_banlanced_datasets.get_data_info(
                 i) == self.dataset.get_data_info(self.repeat_indices[i])
 
     def test_get_cat_ids(self):
-        self._init_dataset()
         for i in range(len(self.repeat_indices)):
             assert self.cls_banlanced_datasets.get_cat_ids(
                 i) == self.dataset.get_cat_ids(self.repeat_indices[i])
-- 
GitLab
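
Usage note (editor's addition, not part of the patch): once the wrappers are
registered in DATASETS, a wrapper and its inner dataset can both be described
as config dicts and built through the registry. A minimal sketch, assuming a
hypothetical `MyDataset` subclass of `BaseDataset` that has been registered
elsewhere via `@DATASETS.register_module()`; the paths are placeholders:

    from mmengine.registry import DATASETS

    # Hypothetical inner dataset config. `MyDataset` stands in for any
    # registered `BaseDataset` subclass.
    dataset_cfg = dict(
        type='MyDataset',
        data_root='data/',
        ann_file='annotations/train.json')

    # The wrapper itself is now resolved by name from the registry, and its
    # __init__ builds the nested `dataset` config via DATASETS.build().
    wrapper_cfg = dict(type='RepeatDataset', dataset=dataset_cfg, times=2)
    repeat_dataset = DATASETS.build(wrapper_cfg)

Passing anything other than a config dict or a `BaseDataset` instance (e.g.
`dataset=[0]`) raises the `TypeError` exercised by the new tests.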