Skip to content
Snippets Groups Projects
Unverified Commit 4dcbd269 authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

[Enhance] register dataset wrapper (#185)

* register dataset wrapper

* fix as comment
parent 3c8806e4
No related branches found
No related tags found
No related merge requests found
......@@ -8,9 +8,11 @@ from typing import List, Sequence, Tuple, Union
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
from mmengine.registry import DATASETS
from .base_dataset import BaseDataset, force_full_init
@DATASETS.register_module()
class ConcatDataset(_ConcatDataset):
"""A wrapper of concatenated dataset.
......@@ -24,19 +26,28 @@ class ConcatDataset(_ConcatDataset):
arguments for wrapped dataset which inherit from ``BaseDataset``.
Args:
datasets (Sequence[BaseDataset]): A list of datasets which will be
concatenated.
datasets (Sequence[BaseDataset] or Sequence[dict]): A list of datasets
which will be concatenated.
lazy_init (bool, optional): Whether to load annotation during
instantiation. Defaults to False.
"""
def __init__(self,
datasets: Sequence[BaseDataset],
datasets: Sequence[Union[BaseDataset, dict]],
lazy_init: bool = False):
self.datasets: List[BaseDataset] = []
for i, dataset in enumerate(datasets):
if isinstance(dataset, dict):
self.datasets.append(DATASETS.build(dataset))
elif isinstance(dataset, BaseDataset):
self.datasets.append(dataset)
else:
raise TypeError(
'elements in datasets sequence should be config or '
f'`BaseDataset` instance, but got {type(dataset)}')
# Only use metainfo of first dataset.
self._metainfo = datasets[0].metainfo
self.datasets = datasets # type: ignore
for i, dataset in enumerate(datasets, 1):
self._metainfo = self.datasets[0].metainfo
for i, dataset in enumerate(self.datasets, 1):
if self._metainfo != dataset.metainfo:
raise ValueError(
f'The meta information of the {i}-th dataset does not '
......@@ -140,6 +151,7 @@ class ConcatDataset(_ConcatDataset):
'dataset first and then use `ConcatDataset`.')
@DATASETS.register_module()
class RepeatDataset:
"""A wrapper of repeated dataset.
......@@ -156,19 +168,27 @@ class RepeatDataset:
arguments for wrapped dataset which inherit from ``BaseDataset``.
Args:
dataset (BaseDataset): The dataset to be repeated.
dataset (BaseDataset or dict): The dataset to be repeated.
times (int): Repeat times.
lazy_init (bool): Whether to load annotation during
instantiation. Defaults to False.
"""
def __init__(self,
dataset: BaseDataset,
dataset: Union[BaseDataset, dict],
times: int,
lazy_init: bool = False):
self.dataset = dataset
self.dataset: BaseDataset
if isinstance(dataset, dict):
self.dataset = DATASETS.build(dataset)
elif isinstance(dataset, BaseDataset):
self.dataset = dataset
else:
raise TypeError(
'elements in datasets sequence should be config or '
f'`BaseDataset` instance, but got {type(dataset)}')
self.times = times
self._metainfo = dataset.metainfo
self._metainfo = self.dataset.metainfo
self._fully_initialized = False
if not lazy_init:
......@@ -283,7 +303,7 @@ class ClassBalancedDataset:
``BaseDataset``.
Args:
dataset (BaseDataset): The dataset to be repeated.
dataset (BaseDataset or dict): The dataset to be repeated.
oversample_thr (float): frequency threshold below which data is
repeated. For categories with ``f_c >= oversample_thr``, there is
no oversampling. For categories with ``f_c < oversample_thr``, the
......@@ -294,12 +314,19 @@ class ClassBalancedDataset:
"""
def __init__(self,
dataset: BaseDataset,
dataset: Union[BaseDataset, dict],
oversample_thr: float,
lazy_init: bool = False):
self.dataset = dataset
if isinstance(dataset, dict):
self.dataset = DATASETS.build(dataset)
elif isinstance(dataset, BaseDataset):
self.dataset = dataset
else:
raise TypeError(
'elements in datasets sequence should be config or '
f'`BaseDataset` instance, but got {type(dataset)}')
self.oversample_thr = oversample_thr
self._metainfo = dataset.metainfo
self._metainfo = self.dataset.metainfo
self._fully_initialized = False
if not lazy_init:
......
......@@ -8,7 +8,7 @@ import torch
from mmengine.dataset import (BaseDataset, ClassBalancedDataset, Compose,
ConcatDataset, RepeatDataset, force_full_init)
from mmengine.registry import TRANSFORMS
from mmengine.registry import DATASETS, TRANSFORMS
def function_pipeline(data_info):
......@@ -27,6 +27,11 @@ class NotCallableTransform:
pass
@DATASETS.register_module()
class CustomDataset(BaseDataset):
pass
class TestBaseDataset:
dataset_type = BaseDataset
data_info = dict(
......@@ -566,7 +571,7 @@ class TestBaseDataset:
class TestConcatDataset:
def _init_dataset(self):
def setup(self):
dataset = BaseDataset
# create dataset_a
......@@ -593,8 +598,25 @@ class TestConcatDataset:
self.cat_datasets = ConcatDataset(
datasets=[self.dataset_a, self.dataset_b])
def test_init(self):
# Test build dataset from cfg.
dataset_cfg_b = dict(
type=CustomDataset,
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='annotations/dummy_annotation.json')
cat_datasets = ConcatDataset(datasets=[self.dataset_a, dataset_cfg_b])
cat_datasets.datasets[1].pipeline = self.dataset_b.pipeline
assert len(cat_datasets) == len(self.cat_datasets)
for i in range(len(cat_datasets)):
assert (cat_datasets.get_data_info(i) ==
self.cat_datasets.get_data_info(i))
assert (cat_datasets[i] == self.cat_datasets[i])
with pytest.raises(TypeError):
ConcatDataset(datasets=[0])
def test_full_init(self):
self._init_dataset()
# test init with lazy_init=True
self.cat_datasets.full_init()
assert len(self.cat_datasets) == 6
......@@ -618,16 +640,13 @@ class TestConcatDataset:
ConcatDataset(datasets=[self.dataset_a, dataset_b])
def test_metainfo(self):
self._init_dataset()
assert self.cat_datasets.metainfo == self.dataset_a.metainfo
def test_length(self):
self._init_dataset()
assert len(self.cat_datasets) == (
len(self.dataset_a) + len(self.dataset_b))
def test_getitem(self):
self._init_dataset()
assert (
self.cat_datasets[0]['imgs'] == self.dataset_a[0]['imgs']).all()
assert (self.cat_datasets[0]['imgs'] !=
......@@ -639,7 +658,6 @@ class TestConcatDataset:
self.dataset_a[-1]['imgs']).all()
def test_get_data_info(self):
self._init_dataset()
assert self.cat_datasets.get_data_info(
0) == self.dataset_a.get_data_info(0)
assert self.cat_datasets.get_data_info(
......@@ -651,7 +669,6 @@ class TestConcatDataset:
-1) != self.dataset_a.get_data_info(-1)
def test_get_ori_dataset_idx(self):
self._init_dataset()
assert self.cat_datasets._get_ori_dataset_idx(3) == (
1, 3 - len(self.dataset_a))
assert self.cat_datasets._get_ori_dataset_idx(-1) == (
......@@ -662,7 +679,7 @@ class TestConcatDataset:
class TestRepeatDataset:
def _init_dataset(self):
def setup(self):
dataset = BaseDataset
data_info = dict(filename='test_img.jpg', height=604, width=640)
dataset.parse_data_info = MagicMock(return_value=data_info)
......@@ -678,9 +695,26 @@ class TestRepeatDataset:
self.repeat_datasets = RepeatDataset(
dataset=self.dataset, times=self.repeat_times)
def test_full_init(self):
self._init_dataset()
def test_init(self):
# Test build dataset from cfg.
dataset_cfg = dict(
type=CustomDataset,
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='annotations/dummy_annotation.json')
repeat_dataset = RepeatDataset(
dataset=dataset_cfg, times=self.repeat_times)
repeat_dataset.dataset.pipeline = self.dataset.pipeline
assert len(repeat_dataset) == len(self.repeat_datasets)
for i in range(len(repeat_dataset)):
assert (repeat_dataset.get_data_info(i) ==
self.repeat_datasets.get_data_info(i))
assert (repeat_dataset[i] == self.repeat_datasets[i])
with pytest.raises(TypeError):
RepeatDataset(dataset=[0], times=5)
def test_full_init(self):
self.repeat_datasets.full_init()
assert len(
self.repeat_datasets) == self.repeat_times * len(self.dataset)
......@@ -697,22 +731,18 @@ class TestRepeatDataset:
self.repeat_datasets.get_subset(1)
def test_metainfo(self):
self._init_dataset()
assert self.repeat_datasets.metainfo == self.dataset.metainfo
def test_length(self):
self._init_dataset()
assert len(
self.repeat_datasets) == len(self.dataset) * self.repeat_times
def test_getitem(self):
self._init_dataset()
for i in range(self.repeat_times):
assert self.repeat_datasets[len(self.dataset) *
i] == self.dataset[0]
def test_get_data_info(self):
self._init_dataset()
for i in range(self.repeat_times):
assert self.repeat_datasets.get_data_info(
len(self.dataset) * i) == self.dataset.get_data_info(0)
......@@ -720,7 +750,7 @@ class TestRepeatDataset:
class TestClassBalancedDataset:
def _init_dataset(self):
def setup(self):
dataset = BaseDataset
data_info = dict(filename='test_img.jpg', height=604, width=640)
dataset.parse_data_info = MagicMock(return_value=data_info)
......@@ -738,17 +768,35 @@ class TestClassBalancedDataset:
dataset=self.dataset, oversample_thr=1e-3)
self.cls_banlanced_datasets.repeat_indices = self.repeat_indices
def test_full_init(self):
self._init_dataset()
def test_init(self):
# Test build dataset from cfg.
dataset_cfg = dict(
type=CustomDataset,
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='annotations/dummy_annotation.json')
cls_banlanced_datasets = ClassBalancedDataset(
dataset=dataset_cfg, oversample_thr=1e-3)
cls_banlanced_datasets.repeat_indices = self.repeat_indices
cls_banlanced_datasets.dataset.pipeline = self.dataset.pipeline
assert len(cls_banlanced_datasets) == len(self.cls_banlanced_datasets)
for i in range(len(cls_banlanced_datasets)):
assert (cls_banlanced_datasets.get_data_info(i) ==
self.cls_banlanced_datasets.get_data_info(i))
assert (
cls_banlanced_datasets[i] == self.cls_banlanced_datasets[i])
with pytest.raises(TypeError):
ClassBalancedDataset(dataset=[0], times=5)
def test_full_init(self):
self.cls_banlanced_datasets.full_init()
self.cls_banlanced_datasets.repeat_indices = self.repeat_indices
assert len(self.cls_banlanced_datasets) == len(self.repeat_indices)
self.cls_banlanced_datasets.full_init()
# Reinit `repeat_indices`.
self.cls_banlanced_datasets._fully_initialized = False
self.cls_banlanced_datasets[1]
self.cls_banlanced_datasets.repeat_indices = self.repeat_indices
assert len(self.cls_banlanced_datasets) == len(self.repeat_indices)
assert len(self.cls_banlanced_datasets) != len(self.repeat_indices)
with pytest.raises(NotImplementedError):
self.cls_banlanced_datasets.get_subset_(1)
......@@ -757,27 +805,22 @@ class TestClassBalancedDataset:
self.cls_banlanced_datasets.get_subset(1)
def test_metainfo(self):
self._init_dataset()
assert self.cls_banlanced_datasets.metainfo == self.dataset.metainfo
def test_length(self):
self._init_dataset()
assert len(self.cls_banlanced_datasets) == len(self.repeat_indices)
def test_getitem(self):
self._init_dataset()
for i in range(len(self.repeat_indices)):
assert self.cls_banlanced_datasets[i] == self.dataset[
self.repeat_indices[i]]
def test_get_data_info(self):
self._init_dataset()
for i in range(len(self.repeat_indices)):
assert self.cls_banlanced_datasets.get_data_info(
i) == self.dataset.get_data_info(self.repeat_indices[i])
def test_get_cat_ids(self):
self._init_dataset()
for i in range(len(self.repeat_indices)):
assert self.cls_banlanced_datasets.get_cat_ids(
i) == self.dataset.get_cat_ids(self.repeat_indices[i])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment