From ada6660c65be862594d2f3dd0991229b7f619d50 Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Tue, 22 Feb 2022 14:01:06 +0800
Subject: [PATCH] [Feature] add base dataset (#32)

* basedataset first commit

* add base dataset

* add dataset

* add basedataset

* Fix test dataset

* Fix mypy and test

* Fix mypy and test

* remove unused code

* Update mmengine/dataset/base_dataset.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmengine/dataset/base_dataset.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* add more corner cases in unittest

* fix lint

* Fix as comment

* Fix lint

* update unittest

* Type hint Dict to dict

* rename max_refetch

* Fix as comment

* Fix typo

* Fix as comment

* BaseDataset is no longer an abstract class, change UT and docs

* Fix as comment

* Fix as comment and refactor type error

* Add comment for full init

* Fix as comment and modify dataset_wrapper

* Fix as comment and modify dataset_wrapper

* Fix as comment

* Fix as comment

* Fix as comment

* Fix as comment

* Fix as comment

* Fix as comment

* Fix as comment

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: Tao Gong <gongtao950513@gmail.com>
---
 docs/zh_cn/tutorials/basedataset.md           |   6 +-
 mmengine/__init__.py                          |   1 +
 mmengine/dataset/__init__.py                  |   4 +
 mmengine/dataset/base_dataset.py              | 553 ++++++++++++++++++
 mmengine/dataset/dataset_wrapper.py           | 364 ++++++++++++
 .../annotations/annotation_wrong_format.json  |   5 +
 ...tation.json => annotation_wrong_keys.json} |   0
 tests/data/annotations/dummy_annotation.json  |   3 +-
 tests/data/meta/classes.txt                   |   1 +
 tests/test_data/test_base_dataset.py          | 404 ++++++++++---
 10 files changed, 1268 insertions(+), 73 deletions(-)
 create mode 100644 mmengine/dataset/__init__.py
 create mode 100644 mmengine/dataset/base_dataset.py
 create mode 100644 mmengine/dataset/dataset_wrapper.py
 create mode 100644 tests/data/annotations/annotation_wrong_format.json
 rename tests/data/annotations/{wrong_annotation.json => annotation_wrong_keys.json} (100%)
 create mode 100644 tests/data/meta/classes.txt

diff --git a/docs/zh_cn/tutorials/basedataset.md b/docs/zh_cn/tutorials/basedataset.md
index 08a49af9..e941a733 100644
--- a/docs/zh_cn/tutorials/basedataset.md
+++ b/docs/zh_cn/tutorials/basedataset.md
@@ -69,7 +69,7 @@ data
 
 2. Build the data pipeline for data preprocessing and data preparation;
 
-3. Read and parse the annotation file that conforms to the OpenMMLab 2.0 dataset format specification; this step involves the abstract method `parse_annotations()`, which parses each raw data item in the annotation file;
+3. Read and parse the annotation file that conforms to the OpenMMLab 2.0 dataset format specification; this step involves the `parse_annotations()` method, which parses each raw data item in the annotation file;
 
 4. Filter out useless data, such as samples without annotations;
 
@@ -77,8 +77,6 @@ data
 
 6. Serialize all samples to save memory; see [Saving Memory](#节省内存) for details.
 
-The dataset base class is an abstract class with one and only one abstract method, `parse_annotations()`, which defines how one raw data item in the annotation file is processed into one or several training/testing samples. Therefore, users must implement the `parse_annotations()` method in their custom dataset classes.
-
 ### Interfaces provided by the dataset base class
 
 Similar to `torch.utils.data.Dataset`, after the dataset is initialized, it supports the `__getitem__` method for indexing data and the `__len__` operation for getting the dataset size. In addition, the OpenMMLab dataset base class mainly provides the following interfaces for accessing specific information:
@@ -93,7 +91,7 @@ data
 
 ## Customizing a dataset class based on the dataset base class
 
-After understanding the initialization process of the dataset base class and the interfaces it provides, you can customize a dataset class based on it. As mentioned above, the dataset base class is an abstract class with one and only one abstract method, `parse_annotations()`, so users must implement this method in their custom dataset classes. Below is an example of using the dataset base class to implement a concrete dataset.
+After understanding the initialization process of the dataset base class and the interfaces it provides, you can customize a dataset class based on it. As mentioned above, for annotation files that conform to the OpenMMLab 2.0 dataset format specification, users can override `parse_annotations()` to load the labels. Below is an example of using the dataset base class to implement a concrete dataset.
 
 ```python
 import os.path as osp
diff --git a/mmengine/__init__.py b/mmengine/__init__.py
index 34a3acc1..d389ac84 100644
--- a/mmengine/__init__.py
+++ b/mmengine/__init__.py
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 # flake8: noqa
 from .config import *
+from .dataset import *
 from .fileio import *
 from .registry import *
 from .utils import *
diff --git a/mmengine/dataset/__init__.py b/mmengine/dataset/__init__.py
new file mode 100644
index 00000000..25ee724b
--- /dev/null
+++ b/mmengine/dataset/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# flake8: noqa
+from .base_dataset import BaseDataset, Compose, force_full_init
+from .dataset_wrapper import ClassBalancedDataset, ConcatDataset, RepeatDataset
diff --git a/mmengine/dataset/base_dataset.py b/mmengine/dataset/base_dataset.py
new file mode 100644
index 00000000..586292eb
--- /dev/null
+++ b/mmengine/dataset/base_dataset.py
@@ -0,0 +1,553 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import functools
+import gc
+import os.path as osp
+import pickle
+import warnings
+from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+from torch.utils.data import Dataset
+
+from mmengine.fileio import list_from_file, load
+from mmengine.registry import TRANSFORMS
+from mmengine.utils import check_file_exist
+
+
+class Compose:
+    """Compose multiple transforms sequentially.
+
+    Args:
+        transforms (Sequence[Union[dict, Callable]]): Sequence of transform
+            objects or config dicts to be composed.
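+
+    Examples:
+        >>> # A minimal sketch of how ``Compose`` is used; the lambda below
+        >>> # stands in for any callable transform and is only illustrative.
+        >>> compose = Compose([lambda data: dict(**data, flipped=False)])
+        >>> compose(dict(img_path='demo.jpg'))
+        {'img_path': 'demo.jpg', 'flipped': False}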
+    """
+
+    def __init__(self, transforms: Sequence[Union[dict, Callable]]):
+        self.transforms: List[Callable] = []
+        for transform in transforms:
+            if isinstance(transform, dict):
+                transform = TRANSFORMS.build(transform)
+                if not callable(transform):
+                    raise TypeError(f'transform should be a callable object, '
+                                    f'but got {type(transform)}')
+                self.transforms.append(transform)
+            elif callable(transform):
+                self.transforms.append(transform)
+            else:
+                raise TypeError(
+                    f'transform must be a callable object or dict, '
+                    f'but got {type(transform)}')
+
+    def __call__(self, data: dict) -> Optional[dict]:
+        """Call function to apply transforms sequentially.
+
+        Args:
+            data (dict): A result dict contains the data to transform.
+
+        Returns:
+            Optional[dict]: Transformed data, or ``None`` if any transform
+            returns ``None``.
+        """
+        for t in self.transforms:
+            data = t(data)
+            if data is None:
+                return None
+        return data
+
+    def __repr__(self):
+        """Print ``self.transforms`` in sequence.
+
+        Returns:
+            str: Formatted string.
+        """
+        format_string = self.__class__.__name__ + '('
+        for t in self.transforms:
+            format_string += '\n'
+            format_string += f'    {t}'
+        format_string += '\n)'
+        return format_string
+
+
+def force_full_init(old_func: Callable) -> Any:
+    """Those methods decorated by ``force_full_init`` will be forced to call
+    ``full_init`` if the instance has not been fully initiated.
+
+    Args:
+        old_func (Callable): Decorated function, make sure the first arg is an
+            instance with ``full_init`` method.
+
+    Returns:
+        Any: Depends on old_func.
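+
+    Examples:
+        >>> # A minimal sketch of the intended usage; ``ToyDataset`` here is
+        >>> # a made-up class, not part of this module.
+        >>> class ToyDataset:
+        ...     _fully_initialized = False
+        ...
+        ...     def full_init(self):
+        ...         self.data = [1, 2, 3]
+        ...
+        ...     @force_full_init
+        ...     def __len__(self):
+        ...         return len(self.data)
+        >>> len(ToyDataset())  # ``full_init`` is triggered automatically
+        3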
+    """
+    # TODO This decorator can also be used in runner.
+    @functools.wraps(old_func)
+    def wrapper(obj: object, *args, **kwargs):
+        if not hasattr(obj, 'full_init'):
+            raise AttributeError(f'{type(obj)} does not have full_init '
+                                 'method.')
+        if not getattr(obj, '_fully_initialized', False):
+            warnings.warn('Attribute `_fully_initialized` is not defined in '
+                          f'{type(obj)} or `{type(obj)}._fully_initialized` '
+                          'is False, `full_init` will be called and '
+                          f'`{type(obj)}._fully_initialized` will be set to '
+                          'True')
+            obj.full_init()  # type: ignore
+            obj._fully_initialized = True  # type: ignore
+
+        return old_func(obj, *args, **kwargs)
+
+    return wrapper
+
+
+class BaseDataset(Dataset):
+    r"""BaseDataset for open source projects in OpenMMLab.
+
+    The annotation format is shown as follows.
+
+    .. code-block:: none
+
+        {
+            "metadata":
+            {
+              "dataset_type": "test_dataset",
+              "task_name": "test_task"
+            },
+            "data_infos":
+            [
+              {
+                "img_path": "test_img.jpg",
+                "height": 604,
+                "width": 640,
+                "instances":
+                [
+                  {
+                    "bbox": [0, 0, 10, 20],
+                    "bbox_label": 1,
+                    "mask": [[0,0],[0,10],[10,20],[20,0]],
+                    "extra_anns": [1,2,3]
+                  },
+                  {
+                    "bbox": [10, 10, 110, 120],
+                    "bbox_label": 2,
+                    "mask": [[10,10],[10,110],[110,120],[120,10]],
+                    "extra_anns": [4,5,6]
+                  }
+                ]
+              },
+            ]
+        }
+
+    Args:
+        ann_file (str): Annotation file path.
+        meta (dict, optional): Meta information for dataset, such as class
+            information. Defaults to None.
+        data_root (str, optional): The root directory for ``data_prefix`` and
+            ``ann_file``. Defaults to None.
+        data_prefix (dict, optional): Prefix for training data. Defaults to
+            dict(img=None, ann=None).
+        filter_cfg (dict, optional): Config for filtering data. Defaults to
+            None.
+        num_samples (int, optional): Use only the first ``num_samples`` data
+            items in the annotation file, which facilitates training/testing
+            on a smaller dataset. Defaults to -1, which means using all
+            ``data_infos``.
+        serialize_data (bool, optional): Whether to hold memory using
+            serialized objects; when enabled, data loader workers can use
+            shared RAM from the master process instead of making a copy.
+            Defaults to True.
+        pipeline (list, optional): Processing pipeline. Defaults to [].
+        test_mode (bool, optional): ``test_mode=True`` means in test phase.
+            Defaults to False.
+        lazy_init (bool, optional): Whether to load annotations during
+            instantiation. In some cases, such as visualization, only the
+            meta information of the dataset is needed and it is not necessary
+            to load the annotation file. ``BaseDataset`` can skip loading
+            annotations to save time by setting ``lazy_init=True``. Defaults
+            to False.
+        max_refetch (int, optional): The maximum number of cycles to get a
+            valid image. Defaults to 1000.
+
+    Note:
+        BaseDataset collects meta information from the annotation file (the
+        lowest priority), ``BaseDataset.META`` (medium priority) and the
+        ``meta`` parameter (highest priority) passed to the constructor.
+        Lower-priority meta information will be overwritten by
+        higher-priority meta information, as shown in the example below.
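+
+    Examples:
+        >>> # A hedged sketch of the priority rule above. ``ToyDataset`` and
+        >>> # ``./data/ann.json`` are made-up names, and the annotation file
+        >>> # is assumed to follow the format shown above.
+        >>> class ToyDataset(BaseDataset):
+        ...     META = dict(classes=('cat', 'dog'))
+        >>> dataset = ToyDataset(
+        ...     data_root='./data',
+        ...     ann_file='ann.json',
+        ...     meta=dict(classes=('person', )))
+        >>> dataset.meta['classes']  # ``meta`` overrides ``META``
+        ('person',)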
+    """
+
+    META: dict = dict()
+    _fully_initialized: bool = False
+
+    def __init__(self,
+                 ann_file: str,
+                 meta: Optional[dict] = None,
+                 data_root: Optional[str] = None,
+                 data_prefix: dict = dict(img=None, ann=None),
+                 filter_cfg: Optional[dict] = None,
+                 num_samples: int = -1,
+                 serialize_data: bool = True,
+                 pipeline: List[Union[dict, Callable]] = [],
+                 test_mode: bool = False,
+                 lazy_init: bool = False,
+                 max_refetch: int = 1000):
+
+        self.data_root = data_root
+        self.data_prefix = copy.copy(data_prefix)
+        self.ann_file = ann_file
+        self.filter_cfg = copy.deepcopy(filter_cfg)
+        self._num_samples = num_samples
+        self.serialize_data = serialize_data
+        self.test_mode = test_mode
+        self.max_refetch = max_refetch
+        self.data_infos: List[dict] = []
+        self.data_infos_bytes: bytearray = bytearray()
+
+        # set meta information
+        self._meta = self._get_meta_data(copy.deepcopy(meta))
+
+        # join paths
+        if self.data_root is not None:
+            self._join_prefix()
+
+        # build pipeline
+        self.pipeline = Compose(pipeline)
+
+        if not lazy_init:
+            self.full_init()
+
+    @force_full_init
+    def get_data_info(self, idx: int) -> dict:
+        """Get annotation by index and automatically call ``full_init`` if the
+        dataset has not been fully initialized.
+
+        Args:
+            idx (int): The index of data.
+
+        Returns:
+            dict: The idx-th annotation of the dataset.
+        """
+        if self.serialize_data:
+            start_addr = 0 if idx == 0 else self.data_address[idx - 1].item()
+            end_addr = self.data_address[idx].item()
+            bytes = memoryview(self.data_infos_bytes[start_addr:end_addr])
+            data_info = pickle.loads(bytes)
+        else:
+            data_info = self.data_infos[idx]
+        # To record the real positive index of data information.
+        if idx >= 0:
+            data_info['sample_idx'] = idx
+        else:
+            data_info['sample_idx'] = len(self) + idx
+
+        return data_info
+
+    def full_init(self):
+        """Load annotation file and set ``BaseDataset._fully_initialized`` to
+        True.
+
+        If ``lazy_init=False``, ``full_init`` will be called during the
+        instantiation and ``self._fully_initialized`` will be set to True. If
+        ``obj._fully_initialized=False``, the class method decorated by
+        ``force_full_init`` will call ``full_init`` automatically.
+
+        Several steps to initialize annotation:
+
+            - load_annotations: Load annotations from annotation file.
+            - filter data information: Filter annotations according to
+              filter_cfg.
+            - slice_data: Slice dataset according to ``self._num_samples``.
+            - serialize_data: Serialize ``self.data_infos`` if
+              ``self.serialize_data`` is True.
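+
+        Examples:
+            >>> # A minimal sketch of lazy initialization; ``toy_dataset`` is
+            >>> # assumed to be built with ``lazy_init=True``.
+            >>> toy_dataset._fully_initialized
+            False
+            >>> toy_dataset.full_init()
+            >>> toy_dataset._fully_initialized
+            True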
+        """
+        if self._fully_initialized:
+            return
+
+        # load data information
+        self.data_infos = self.load_annotations(self.ann_file)
+        # filter illegal data, such as data that has no annotations.
+        self.data_infos = self.filter_data()
+        # if `num_samples > 0`, return the first `num_samples` data information
+        self.data_infos = self._slice_data()
+        # serialize data_infos
+        if self.serialize_data:
+            self.data_infos_bytes, self.data_address = self._serialize_data()
+            # Clear ``self.data_infos`` to prevent each data loader worker
+            # from holding its own copy of the un-serialized annotations.
+            self.data_infos.clear()
+            gc.collect()
+        self._fully_initialized = True
+
+    @property
+    def meta(self) -> dict:
+        """Get meta information of dataset.
+
+        Returns:
+            dict: meta information collected from ``BaseDataset.META``,
+            annotation file and meta parameter during instantiation.
+        """
+        return copy.deepcopy(self._meta)
+
+    def parse_annotations(self,
+                          raw_data_info: dict) -> Union[dict, List[dict]]:
+        """Parse raw annotation to target format.
+
+        ``parse_annotations`` should return ``dict`` or ``List[dict]``. Each
+        dict contains the annotations of a training sample. If the protocol of
+        the sample annotations is changed, this function can be overridden to
+        update the parsing logic while keeping compatibility.
+
+        Args:
+            raw_data_info (dict): Raw annotation loaded from ``ann_file``.
+
+        Returns:
+            Union[dict, List[dict]]: Parsed annotation.
+        """
+        return raw_data_info
+
+    def filter_data(self) -> List[dict]:
+        """Filter annotations according to filter_cfg. Defaults return all
+        ``data_infos``.
+
+        If some ``data_infos`` could be filtered according to specific logic,
+        the subclass should override this method.
+
+        Returns:
+            List[dict]: Filtered results.
+        """
+        return self.data_infos
+
+    def get_cat_ids(self, idx: int) -> List[int]:
+        """Get category ids by index. Dataset wrapped by ClassBalancedDataset
+        must implement this method.
+
+        The ``ClassBalancedDataset`` requires a subclass which implements this
+        method.
+
+        Args:
+            idx (int): The index of data.
+
+        Returns:
+            List[int]: All categories in the image of specified index.
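+
+        Examples:
+            >>> # A hedged sketch of a subclass override; the annotation key
+            >>> # ``bbox_labels`` is illustrative, not a required field.
+            >>> class ToyDataset(BaseDataset):
+            ...     def get_cat_ids(self, idx: int) -> List[int]:
+            ...         data_info = self.get_data_info(idx)
+            ...         return list(data_info['bbox_labels'])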
+        """
+        raise NotImplementedError(f'{type(self)} must implement `get_cat_ids` '
+                                  'method')
+
+    def __getitem__(self, idx: int) -> dict:
+        """Get the idx-th image of dataset after ``self.pipelines`` and
+        ``full_init`` will be called if the  dataset has not been fully
+        initialized.
+
+        During training phase, if ``self.pipelines`` get ``None``,
+        ``self._rand_another`` will be called until a valid image is fetched or
+         the maximum limit of refetech is reached.
+
+        Args:
+            idx (int): The index of ``self.data_infos``.
+
+        Returns:
+            dict: The idx-th image of the dataset after ``self.pipeline``.
+        """
+        if not self._fully_initialized:
+            warnings.warn(
+                'Please call `full_init()` method manually to accelerate '
+                'the speed.')
+            self.full_init()
+
+        if self.test_mode:
+            return self._prepare_data(idx)
+
+        for _ in range(self.max_refetch):
+            data_sample = self._prepare_data(idx)
+            if data_sample is None:
+                idx = self._rand_another()
+                continue
+            return data_sample
+
+        raise Exception(f'Cannot find valid image after {self.max_refetch} '
+                        'attempts! Please check your image path and pipeline')
+
+    def load_annotations(self, ann_file: str) -> List[dict]:
+        """Load annotations from an annotation file.
+
+        If the annotation file does not follow the `OpenMMLab 2.0 dataset
+        format <https://github.com/open-mmlab/mmengine/blob/main/docs/zh_cn/tutorials/basedataset.md>`_ ,
+        the subclass must override this method to load annotations.
+
+        Args:
+            ann_file (str): Absolute annotation file path if
+                ``self.data_root=None`` or relative path if
+                ``self.data_root=/path/to/data/``.
+
+        Returns:
+            List[dict]: A list of annotations.
+        """  # noqa: E501
+        check_file_exist(ann_file)
+        annotations = load(ann_file)
+        if not isinstance(annotations, dict):
+            raise TypeError(f'The annotations loaded from annotation file '
+                            f'should be a dict, but got {type(annotations)}!')
+        if 'data_infos' not in annotations or 'metadata' not in annotations:
+            raise ValueError('Annotation must have data_infos and metadata '
+                             'keys')
+        meta_data = annotations['metadata']
+        raw_data_infos = annotations['data_infos']
+
+        # update self._meta
+        for k, v in meta_data.items():
+            # We only merge keys that are not contained in self._meta.
+            self._meta.setdefault(k, v)
+
+        # load and parse data_infos
+        data_infos = []
+        for raw_data_info in raw_data_infos:
+            data_info = self.parse_annotations(raw_data_info)
+            if isinstance(data_info, dict):
+                data_infos.append(data_info)
+            elif isinstance(data_info, list):
+                # data_info can also be a list of dicts, which means one raw
+                # data_info contains multiple training samples.
+                for item in data_info:
+                    if not isinstance(item, dict):
+                        raise TypeError('data_info must be a list[dict], but '
+                                        f'got {type(item)}')
+                data_infos.extend(data_info)
+            else:
+                raise TypeError('data_info should be a dict or list[dict], '
+                                f'but got {type(data_info)}')
+
+        return data_infos
+
+    @classmethod
+    def _get_meta_data(cls, in_meta: Optional[dict] = None) -> dict:
+        """Collect meta information from the dictionary of meta.
+
+        Args:
+            in_meta (dict, optional): Meta information dict. If a value in
+                ``in_meta`` is the path of an existing file, it will be
+                parsed by ``list_from_file``.
+
+        Returns:
+            dict: Parsed meta information.
+        """
+        # cls.META will be overwritten by in_meta
+        cls_meta = copy.deepcopy(cls.META)
+        if in_meta is None:
+            return cls_meta
+        if not isinstance(in_meta, dict):
+            raise TypeError(
+                f'in_meta should be a dict, but got {type(in_meta)}')
+
+        for k, v in in_meta.items():
+            if isinstance(v, str) and osp.isfile(v):
+                # If the value is the path of an existing file, parse it with
+                # `list_from_file`; file paths inside the parsed file are not
+                # parsed further.
+                cls_meta[k] = list_from_file(v)
+            else:
+                cls_meta[k] = v
+
+        return cls_meta
+
+    def _join_prefix(self):
+        """Join ``self.data_root`` with ``self.data_prefix`` and
+        ``self.ann_file``.
+
+        Examples:
+            >>> # self.data_prefix contains relative paths
+            >>> self.data_root = 'a/b/c'
+            >>> self.data_prefix = dict(img='d/e/')
+            >>> self.ann_file = 'f'
+            >>> self._join_prefix()
+            >>> self.data_prefix
+            dict(img='a/b/c/d/e')
+            >>> self.ann_file
+            'a/b/c/f'
+            >>> # self.data_prefix contains absolute paths
+            >>> self.data_root = 'a/b/c'
+            >>> self.data_prefix = dict(img='/d/e/')
+            >>> self.ann_file = 'f'
+            >>> self._join_prefix()
+            >>> self.data_prefix
+            dict(img='/d/e')
+            >>> self.ann_file
+            'a/b/c/f'
+        """
+        if not osp.isabs(self.ann_file):
+            self.ann_file = osp.join(self.data_root, self.ann_file)
+
+        for data_key, prefix in self.data_prefix.items():
+            if prefix is None:
+                self.data_prefix[data_key] = self.data_root
+            elif isinstance(prefix, str):
+                if not osp.isabs(prefix):
+                    self.data_prefix[data_key] = osp.join(
+                        self.data_root, prefix)
+            else:
+                raise TypeError('prefix should be a string or None, but got '
+                                f'{type(prefix)}')
+
+    def _slice_data(self) -> List[dict]:
+        """Slice ``self.data_infos``. BaseDataset supports only using the first
+        few data.
+
+        Returns:
+            List[dict]: A slice of ``self.data_infos``
+        """
+        assert self._num_samples < len(self.data_infos), \
+            f'Slice size ({self._num_samples}) is larger than dataset ' \
+            f'size ({len(self.data_infos)}), please keep `num_samples` ' \
+            f'smaller than {len(self.data_infos)}'
+        if self._num_samples > 0:
+            return self.data_infos[:self._num_samples]
+        else:
+            return self.data_infos
+
+    def _serialize_data(self) -> Tuple[np.ndarray, np.ndarray]:
+        """Serialize ``self.data_infos`` to save memory when launching multiple
+        workers in data loading. This function will be called in ``full_init``.
+
+        Hold memory using serialized objects, and data loader workers can use
+        shared RAM from master process instead of making a copy.
+
+        Returns:
+            Tuple[np.ndarray, np.ndarray]: Serialized ``data_infos`` and the
+            corresponding address array.
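+
+        For example, if the serialized items have lengths [10, 20, 30], the
+        address array is [10, 30, 60] and the bytes of the item at index 1
+        occupy the slice [10, 30).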
+        """
+
+        def _serialize(data):
+            buffer = pickle.dumps(data, protocol=4)
+            return np.frombuffer(buffer, dtype=np.uint8)
+
+        serialized_data_infos_list = [_serialize(x) for x in self.data_infos]
+        address_list = np.asarray([len(x) for x in serialized_data_infos_list],
+                                  dtype=np.int64)
+        data_address: np.ndarray = np.cumsum(address_list)
+        serialized_data_infos = np.concatenate(serialized_data_infos_list)
+
+        return serialized_data_infos, data_address
+
+    def _rand_another(self) -> int:
+        """Get random index.
+
+        Returns:
+            int: Random index from 0 to ``len(self)-1``
+        """
+        return np.random.randint(0, len(self))
+
+    def _prepare_data(self, idx: int) -> Any:
+        """Get data processed by ``self.pipeline``.
+
+        Args:
+            idx (int): The index of ``data_info``.
+
+        Returns:
+            Any: Depends on ``self.pipeline``.
+        """
+        data_info = self.get_data_info(idx)
+        return self.pipeline(data_info)
+
+    @force_full_init
+    def __len__(self) -> int:
+        """Get the length of filtered dataset and automatically call
+        ``full_init`` if the  dataset has not been fully init.
+
+        Returns:
+            int: The length of filtered dataset.
+        """
+        if self.serialize_data:
+            return len(self.data_address)
+        else:
+            return len(self.data_infos)
diff --git a/mmengine/dataset/dataset_wrapper.py b/mmengine/dataset/dataset_wrapper.py
new file mode 100644
index 00000000..ba8c83cf
--- /dev/null
+++ b/mmengine/dataset/dataset_wrapper.py
@@ -0,0 +1,364 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import bisect
+import copy
+import math
+import warnings
+from collections import defaultdict
+from typing import List, Sequence, Tuple
+
+from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
+
+from .base_dataset import BaseDataset, force_full_init
+
+
+class ConcatDataset(_ConcatDataset):
+    """A wrapper of concatenated dataset.
+
+    Same as ``torch.utils.data.dataset.ConcatDataset``, but supports
+    lazy initialization.
+
+    Args:
+        datasets (Sequence[BaseDataset]): A list of datasets which will be
+            concatenated.
+        lazy_init (bool, optional): Whether to load annotation during
+            instantiation. Defaults to False.
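+
+    Note:
+        A minimal sketch (``dataset_a`` and ``dataset_b`` stand for any two
+        ``BaseDataset`` instances with matching meta):
+        ``ConcatDataset([dataset_a, dataset_b])`` behaves like a single
+        dataset of length ``len(dataset_a) + len(dataset_b)``, where global
+        indices fall into ``dataset_a`` first and then ``dataset_b``.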
+    """
+
+    def __init__(self,
+                 datasets: Sequence[BaseDataset],
+                 lazy_init: bool = False):
+        # Only use meta of first dataset.
+        self._meta = datasets[0].meta
+        self.datasets = datasets  # type: ignore
+        for i, dataset in enumerate(datasets, 1):
+            if self._meta != dataset.meta:
+                warnings.warn(
+                    f'The meta information of the {i}-th dataset does not '
+                    'match meta information of the first dataset')
+
+        self._fully_initialized = False
+        if not lazy_init:
+            self.full_init()
+
+    @property
+    def meta(self) -> dict:
+        """Get the meta information of the first dataset in ``self.datasets``.
+
+        Returns:
+            dict: Meta information of first dataset.
+        """
+        # Prevent `self._meta` from being modified by outside.
+        return copy.deepcopy(self._meta)
+
+    def full_init(self):
+        """Loop to ``full_init`` each dataset."""
+        if self._fully_initialized:
+            return
+        for d in self.datasets:
+            d.full_init()
+        # Get the cumulative sizes of `self.datasets`. For example, if the
+        # lengths of `self.datasets` are [2, 3, 4], the cumulative sizes are
+        # [2, 5, 9].
+        super().__init__(self.datasets)
+        self._fully_initialized = True
+
+    @force_full_init
+    def _get_ori_dataset_idx(self, idx: int) -> Tuple[int, int]:
+        """Convert global idx to local index.
+
+        Args:
+            idx (int): Global index of ``ConcatDataset``.
+
+        Returns:
+            Tuple[int, int]: The index of ``self.datasets`` and the local
+            index of data.
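+
+        For example, with ``self.cumulative_sizes = [2, 5, 9]``, the global
+        index ``3`` is mapped to ``(dataset_idx=1, sample_idx=1)``.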
+        """
+        if idx < 0:
+            if -idx > len(self):
+                raise ValueError(
+                    f'absolute value of index({idx}) should not exceed '
+                    f'dataset length({len(self)}).')
+            idx = len(self) + idx
+        # Get the inner index of single dataset
+        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+        if dataset_idx == 0:
+            sample_idx = idx
+        else:
+            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+
+        return dataset_idx, sample_idx
+
+    @force_full_init
+    def get_data_info(self, idx: int) -> dict:
+        """Get annotation by index.
+
+        Args:
+            idx (int): Global index of ``ConcatDataset``.
+
+        Returns:
+            dict: The idx-th annotation of the datasets.
+        """
+        dataset_idx, sample_idx = self._get_ori_dataset_idx(idx)
+        return self.datasets[dataset_idx].get_data_info(sample_idx)
+
+    @force_full_init
+    def __len__(self):
+        return super().__len__()
+
+    def __getitem__(self, idx):
+        if not self._fully_initialized:
+            warnings.warn('Please call `full_init` method manually to '
+                          'accelerate the speed.')
+            self.full_init()
+        dataset_idx, sample_idx = self._get_ori_dataset_idx(idx)
+        return self.datasets[dataset_idx][sample_idx]
+
+
+class RepeatDataset:
+    """A wrapper of repeated dataset.
+
+    The length of the repeated dataset will be ``times`` times as large as
+    the original dataset. This is useful when the data loading time is long
+    but the dataset is small. Using RepeatDataset can reduce the data loading
+    time between epochs.
+
+    Args:
+        dataset (BaseDataset): The dataset to be repeated.
+        times (int): Repeat times.
+        lazy_init (bool, optional): Whether to load annotation during
+            instantiation. Defaults to False.
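+
+    Note:
+        A minimal sketch (``toy_dataset`` stands for any ``BaseDataset``
+        instance): ``RepeatDataset(toy_dataset, times=3)`` has length
+        ``3 * len(toy_dataset)``, and the global index ``i`` is mapped back
+        to ``i % len(toy_dataset)`` when fetching data.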
+    """
+
+    def __init__(self,
+                 dataset: BaseDataset,
+                 times: int,
+                 lazy_init: bool = False):
+        self.dataset = dataset
+        self.times = times
+        self._meta = dataset.meta
+
+        self._fully_initialized = False
+        if not lazy_init:
+            self.full_init()
+
+    @property
+    def meta(self) -> dict:
+        """Get the meta information of the repeated dataset.
+
+        Returns:
+            dict: The meta information of repeated dataset.
+        """
+        return copy.deepcopy(self._meta)
+
+    def full_init(self):
+        """Loop to ``full_init`` each dataset."""
+        if self._fully_initialized:
+            return
+
+        self.dataset.full_init()
+        self._ori_len = len(self.dataset)
+        self._fully_initialized = True
+
+    @force_full_init
+    def _get_ori_dataset_idx(self, idx: int) -> int:
+        """Convert global index to local index.
+
+        Args:
+            idx (int): Global index of ``RepeatDataset``.
+
+        Returns:
+            int: Local index of data.
+        """
+        return idx % self._ori_len
+
+    @force_full_init
+    def get_data_info(self, idx: int) -> dict:
+        """Get annotation by index.
+
+        Args:
+            idx (int): Global index of ``RepeatDataset``.
+
+        Returns:
+            dict: The idx-th annotation of the wrapped dataset.
+        """
+        sample_idx = self._get_ori_dataset_idx(idx)
+        return self.dataset.get_data_info(sample_idx)
+
+    def __getitem__(self, idx):
+        if not self._fully_initialized:
+            warnings.warn('Please call `full_init` method manually to '
+                          'accelerate the speed.')
+            self.full_init()
+
+        sample_idx = self._get_ori_dataset_idx(idx)
+        return self.dataset[sample_idx]
+
+    @force_full_init
+    def __len__(self):
+        return self.times * self._ori_len
+
+
+class ClassBalancedDataset:
+    """A wrapper of class balanced dataset.
+
+    Suitable for training on class imbalanced datasets like LVIS. Following
+    the sampling strategy in the `paper <https://arxiv.org/abs/1908.03195>`_,
+    in each epoch, an image may appear multiple times based on its
+    "repeat factor".
+    The repeat factor for an image is a function of the frequency of the
+    rarest category labeled in that image. The "frequency of category c" in
+    [0, 1] is defined as the fraction of images in the training set (without
+    repeats) in which category c appears.
+    The wrapped dataset needs to implement :meth:`get_cat_ids` to support
+    ClassBalancedDataset.
+
+    The repeat factor is computed as follows (see the worked case in the note
+    below).
+
+    1. For each category c, compute the fraction of images
+       that contain it: :math:`f(c)`
+    2. For each category c, compute the category-level repeat factor:
+       :math:`r(c) = max(1, sqrt(t/f(c)))`
+    3. For each image I, compute the image-level repeat factor:
+       :math:`r(I) = max_{c in I} r(c)`
+
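+    Note:
+        A hedged worked case (the numbers are illustrative, not from any real
+        dataset): with ``oversample_thr = 0.5`` and a category appearing in
+        1 of 8 images, ``f(c) = 0.125`` and ``r(c) = sqrt(0.5 / 0.125) = 2``,
+        so an image whose rarest category is c gets a repeat factor of 2 and
+        appears ``math.ceil(2) = 2`` times per epoch.
+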
+    Args:
+        dataset (BaseDataset): The dataset to be repeated.
+        oversample_thr (float): Frequency threshold below which data is
+            repeated. For categories with ``f_c >= oversample_thr``, there is
+            no oversampling. For categories with ``f_c < oversample_thr``, the
+            degree of oversampling follows the square-root inverse frequency
+            heuristic above.
+        lazy_init (bool, optional): Whether to load annotation during
+            instantiation. Defaults to False.
+    """
+
+    def __init__(self,
+                 dataset: BaseDataset,
+                 oversample_thr: float,
+                 lazy_init: bool = False):
+        self.dataset = dataset
+        self.oversample_thr = oversample_thr
+        self._meta = dataset.meta
+
+        self._fully_initialized = False
+        if not lazy_init:
+            self.full_init()
+
+    @property
+    def meta(self) -> dict:
+        """Get the meta information of the repeated dataset.
+
+        Returns:
+            dict: The meta information of repeated dataset.
+        """
+        return copy.deepcopy(self._meta)
+
+    def full_init(self):
+        """Loop to ``full_init`` each dataset."""
+        if self._fully_initialized:
+            return
+
+        self.dataset.full_init()
+
+        repeat_factors = self._get_repeat_factors(self.dataset,
+                                                  self.oversample_thr)
+        repeat_indices = []
+        for dataset_index, repeat_factor in enumerate(repeat_factors):
+            repeat_indices.extend([dataset_index] * math.ceil(repeat_factor))
+        self.repeat_indices = repeat_indices
+
+        self._fully_initialized = True
+
+    def _get_repeat_factors(self, dataset: BaseDataset,
+                            repeat_thr: float) -> List[float]:
+        """Get repeat factor for each images in the dataset.
+
+        Args:
+            dataset (BaseDataset): The dataset.
+            repeat_thr (float): The threshold of frequency. If an image
+                contains categories whose frequency is below the threshold,
+                it will be repeated.
+
+        Returns:
+            List[float]: The repeat factors for each image in the dataset.
+        """
+        # 1. For each category c, compute the fraction of images
+        #    that contain it: f(c)
+        category_freq: defaultdict = defaultdict(float)
+        num_images = len(dataset)
+        for idx in range(num_images):
+            cat_ids = set(self.dataset.get_cat_ids(idx))
+            for cat_id in cat_ids:
+                category_freq[cat_id] += 1
+        for k, v in category_freq.items():
+            assert v > 0, f'category {k} does not contain any images'
+            category_freq[k] = v / num_images
+
+        # 2. For each category c, compute the category-level repeat factor:
+        #    r(c) = max(1, sqrt(t/f(c)))
+        category_repeat = {
+            cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq))
+            for cat_id, cat_freq in category_freq.items()
+        }
+
+        # 3. For each image I and its labels L(I), compute the image-level
+        # repeat factor:
+        #    r(I) = max_{c in L(I)} r(c)
+        repeat_factors = []
+        for idx in range(num_images):
+            cat_ids = set(self.dataset.get_cat_ids(idx))
+            repeat_factor = max(
+                {category_repeat[cat_id]
+                 for cat_id in cat_ids})
+            repeat_factors.append(repeat_factor)
+
+        return repeat_factors
+
+    @force_full_init
+    def _get_ori_dataset_idx(self, idx: int) -> int:
+        """Convert global index to local index.
+
+        Args:
+            idx (int): Global index of ``ClassBalancedDataset``.
+
+        Returns:
+            int: Local index of data.
+        """
+        return self.repeat_indices[idx]
+
+    @force_full_init
+    def get_cat_ids(self, idx: int) -> List[int]:
+        """Get category ids of class balanced dataset by index.
+
+        Args:
+            idx (int): Index of data.
+
+        Returns:
+            List[int]: All categories in the image of specified index.
+        """
+        sample_idx = self._get_ori_dataset_idx(idx)
+        return self.dataset.get_cat_ids(sample_idx)
+
+    @force_full_init
+    def get_data_info(self, idx: int) -> dict:
+        """Get annotation by index.
+
+        Args:
+            idx (int): Global index of ``ClassBalancedDataset``.
+
+        Returns:
+            dict: The idx-th annotation of the dataset.
+        """
+        sample_idx = self._get_ori_dataset_idx(idx)
+        return self.dataset.get_data_info(sample_idx)
+
+    def __getitem__(self, idx):
+        if not self._fully_initialized:
+            warnings.warn('Please call `full_init` method manually to '
+                          'accelerate the speed.')
+            self.full_init()
+
+        ori_index = self._get_ori_dataset_idx(idx)
+        return self.dataset[ori_index]
+
+    @force_full_init
+    def __len__(self):
+        return len(self.repeat_indices)
diff --git a/tests/data/annotations/annotation_wrong_format.json b/tests/data/annotations/annotation_wrong_format.json
new file mode 100644
index 00000000..b8dadc0d
--- /dev/null
+++ b/tests/data/annotations/annotation_wrong_format.json
@@ -0,0 +1,5 @@
+[
+    {
+        "img_path": "test_img.jpg"
+    }
+]
\ No newline at end of file
diff --git a/tests/data/annotations/wrong_annotation.json b/tests/data/annotations/annotation_wrong_keys.json
similarity index 100%
rename from tests/data/annotations/wrong_annotation.json
rename to tests/data/annotations/annotation_wrong_keys.json
diff --git a/tests/data/annotations/dummy_annotation.json b/tests/data/annotations/dummy_annotation.json
index abba398a..d36109ce 100644
--- a/tests/data/annotations/dummy_annotation.json
+++ b/tests/data/annotations/dummy_annotation.json
@@ -2,7 +2,8 @@
     "metadata":
     {
       "dataset_type": "test_dataset",
-      "task_name": "test_task"
+      "task_name": "test_task",
+      "empty_list": []
     },
     "data_infos":
     [
diff --git a/tests/data/meta/classes.txt b/tests/data/meta/classes.txt
new file mode 100644
index 00000000..18a619c9
--- /dev/null
+++ b/tests/data/meta/classes.txt
@@ -0,0 +1 @@
+dog
diff --git a/tests/test_data/test_base_dataset.py b/tests/test_data/test_base_dataset.py
index 2a33427c..e695488c 100644
--- a/tests/test_data/test_base_dataset.py
+++ b/tests/test_data/test_base_dataset.py
@@ -1,40 +1,66 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import copy
 import os.path as osp
 from unittest.mock import MagicMock
 
 import pytest
 import torch
 
-from mmengine.data import (BaseDataset, ClassBalancedDataset, ConcatDataset,
-                           RepeatDataset)
+from mmengine.dataset import (BaseDataset, ClassBalancedDataset, Compose,
+                              ConcatDataset, RepeatDataset, force_full_init)
+from mmengine.registry import TRANSFORMS
 
 
-class TestBaseDataset:
+def function_pipeline(data_info):
+    return data_info
+
+
+@TRANSFORMS.register_module()
+class CallableTransform:
+
+    def __call__(self, data_info):
+        return data_info
 
-    def __init__(self):
-        self.base_dataset = BaseDataset
 
-        self.data_info = dict(filename='test_img.jpg', height=604, width=640)
-        self.base_dataset.parse_annotations = MagicMock(
-            return_value=self.data_info)
+@TRANSFORMS.register_module()
+class NotCallableTransform:
+    pass
 
-        self.imgs = torch.rand((2, 3, 32, 32))
-        self.base_dataset.pipeline = MagicMock(
-            return_value=dict(imgs=self.imgs))
+
+class TestBaseDataset:
+    dataset_type = BaseDataset
+    data_info = dict(
+        filename='test_img.jpg', height=604, width=640, sample_idx=0)
+    imgs = torch.rand((2, 3, 32, 32))
+    pipeline = MagicMock(return_value=dict(imgs=imgs))
+    META: dict = dict()
+    parse_annotations = MagicMock(return_value=data_info)
+
+    def _init_dataset(self):
+        self.dataset_type.META = self.META
+        self.dataset_type.parse_annotations = self.parse_annotations
 
     def test_init(self):
+        self._init_dataset()
         # test the instantiation of self.base_dataset
-        dataset = self.base_dataset(
+        dataset = self.dataset_type(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img='imgs'),
             ann_file='annotations/dummy_annotation.json')
         assert dataset._fully_initialized
         assert hasattr(dataset, 'data_infos')
         assert hasattr(dataset, 'data_address')
+        dataset = self.dataset_type(
+            data_root=osp.join(osp.dirname(__file__), '../data/'),
+            data_prefix=dict(img=None),
+            ann_file='annotations/dummy_annotation.json')
+        assert dataset._fully_initialized
+        assert hasattr(dataset, 'data_infos')
+        assert hasattr(dataset, 'data_address')
 
         # test the instantiation of self.base_dataset with
         # `serialize_data=False`
-        dataset = self.base_dataset(
+        dataset = self.dataset_type(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img='imgs'),
             ann_file='annotations/dummy_annotation.json',
@@ -42,33 +68,54 @@ class TestBaseDataset:
         assert dataset._fully_initialized
         assert hasattr(dataset, 'data_infos')
         assert not hasattr(dataset, 'data_address')
+        assert len(dataset) == 2
+        assert dataset.get_data_info(0) == self.data_info
 
         # test the instantiation of self.base_dataset with lazy init
-        dataset = self.base_dataset(
+        dataset = self.dataset_type(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img='imgs'),
             ann_file='annotations/dummy_annotation.json',
             lazy_init=True)
         assert not dataset._fully_initialized
-        assert not hasattr(dataset, 'data_infos')
+        assert not dataset.data_infos
+
+        # test the instantiation of self.base_dataset if the ann_file does
+        # not exist.
+        with pytest.raises(FileNotFoundError):
+            self.dataset_type(
+                data_root=osp.join(osp.dirname(__file__), '../data/'),
+                data_prefix=dict(img='imgs'),
+                ann_file='annotations/not_existed_annotation.json')
 
         # test the instantiation of self.base_dataset when the ann_file is
         # wrong
         with pytest.raises(ValueError):
-            self.base_dataset(
+            self.dataset_type(
+                data_root=osp.join(osp.dirname(__file__), '../data/'),
+                data_prefix=dict(img='imgs'),
+                ann_file='annotations/annotation_wrong_keys.json')
+        with pytest.raises(TypeError):
+            self.dataset_type(
                 data_root=osp.join(osp.dirname(__file__), '../data/'),
                 data_prefix=dict(img='imgs'),
-                ann_file='annotations/wrong_annotation.json')
+                ann_file='annotations/annotation_wrong_format.json')
+        with pytest.raises(TypeError):
+            self.dataset_type(
+                data_root=osp.join(osp.dirname(__file__), '../data/'),
+                data_prefix=dict(img=['img']),
+                ann_file='annotations/annotation_wrong_format.json')
 
         # test the instantiation of self.base_dataset when `parse_annotations`
         # return `list[dict]`
-        self.base_dataset.parse_annotations = MagicMock(
+        self.dataset_type.parse_annotations = MagicMock(
             return_value=[self.data_info,
                           self.data_info.copy()])
-        dataset = self.base_dataset(
+        dataset = self.dataset_type(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img='imgs'),
             ann_file='annotations/dummy_annotation.json')
+        dataset.pipeline = self.pipeline
         assert dataset._fully_initialized
         assert hasattr(dataset, 'data_infos')
         assert hasattr(dataset, 'data_address')
@@ -76,98 +123,139 @@ class TestBaseDataset:
         assert dataset[0] == dict(imgs=self.imgs)
         assert dataset.get_data_info(0) == self.data_info
 
-        # set self.base_dataset to initial state
-        self.__init__()
+        # test the instantiation of self.base_dataset when `parse_annotations`
+        # returns unsupported data.
+        with pytest.raises(TypeError):
+            self.dataset_type.parse_annotations = MagicMock(return_value='xxx')
+            dataset = self.dataset_type(
+                data_root=osp.join(osp.dirname(__file__), '../data/'),
+                data_prefix=dict(img='imgs'),
+                ann_file='annotations/dummy_annotation.json')
+        with pytest.raises(TypeError):
+            self.dataset_type.parse_annotations = MagicMock(
+                return_value=[self.data_info, 'xxx'])
+            dataset = self.dataset_type(
+                data_root=osp.join(osp.dirname(__file__), '../data/'),
+                data_prefix=dict(img='imgs'),
+                ann_file='annotations/dummy_annotation.json')
 
     def test_meta(self):
+        self._init_dataset()
         # test dataset.meta with setting the meta from annotation file as the
         # meta of self.base_dataset
-        dataset = self.base_dataset(
+        dataset = self.dataset_type(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img='imgs'),
             ann_file='annotations/dummy_annotation.json')
+
         assert dataset.meta == dict(
-            dataset_type='test_dataset', task_name='test_task')
+            dataset_type='test_dataset', task_name='test_task', empty_list=[])
 
         # test dataset.meta with setting META in self.base_dataset
         dataset_type = 'new_dataset'
-        self.base_dataset.META = dict(
+        self.dataset_type.META = dict(
             dataset_type=dataset_type, classes=('dog', 'cat'))
 
-        dataset = self.base_dataset(
+        dataset = self.dataset_type(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img='imgs'),
             ann_file='annotations/dummy_annotation.json')
         assert dataset.meta == dict(
             dataset_type=dataset_type,
             task_name='test_task',
-            classes=('dog', 'cat'))
+            classes=('dog', 'cat'),
+            empty_list=[])
 
         # test dataset.meta with passing meta into self.base_dataset
-        meta = dict(classes=('dog', ))
-        dataset = self.base_dataset(
+        meta = dict(classes=('dog', ), task_name='new_task')
+        dataset = self.dataset_type(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img='imgs'),
             ann_file='annotations/dummy_annotation.json',
             meta=meta)
-        assert self.base_dataset.META == dict(
+        assert self.dataset_type.META == dict(
             dataset_type=dataset_type, classes=('dog', 'cat'))
         assert dataset.meta == dict(
             dataset_type=dataset_type,
-            task_name='test_task',
-            classes=('dog', ))
+            task_name='new_task',
+            classes=('dog', ),
+            empty_list=[])
         # reset `base_dataset.META`, the `dataset.meta` should not change
-        self.base_dataset.META['classes'] = ('dog', 'cat', 'fish')
-        assert self.base_dataset.META == dict(
+        self.dataset_type.META['classes'] = ('dog', 'cat', 'fish')
+        assert self.dataset_type.META == dict(
             dataset_type=dataset_type, classes=('dog', 'cat', 'fish'))
+        assert dataset.meta == dict(
+            dataset_type=dataset_type,
+            task_name='new_task',
+            classes=('dog', ),
+            empty_list=[])
+
+        # test dataset.meta with passing meta containing a file into
+        # self.base_dataset
+        meta = dict(
+            classes=osp.join(
+                osp.dirname(__file__), '../data/meta/classes.txt'))
+        dataset = self.dataset_type(
+            data_root=osp.join(osp.dirname(__file__), '../data/'),
+            data_prefix=dict(img='imgs'),
+            ann_file='annotations/dummy_annotation.json',
+            meta=meta)
         assert dataset.meta == dict(
             dataset_type=dataset_type,
             task_name='test_task',
-            classes=('dog', ))
+            classes=['dog'],
+            empty_list=[])
+
+        # test dataset.meta with passing unsupported meta into
+        # self.base_dataset
+        with pytest.raises(TypeError):
+            meta = 'dog'
+            dataset = self.dataset_type(
+                data_root=osp.join(osp.dirname(__file__), '../data/'),
+                data_prefix=dict(img='imgs'),
+                ann_file='annotations/dummy_annotation.json',
+                meta=meta)
 
         # test dataset.meta with passing meta into self.base_dataset and
         # lazy_init is True
         meta = dict(classes=('dog', ))
-        dataset = self.base_dataset(
+        dataset = self.dataset_type(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img='imgs'),
             ann_file='annotations/dummy_annotation.json',
             meta=meta,
             lazy_init=True)
-        # 'task_name' not in dataset.meta
+        # 'task_name' and 'empty_list' not in dataset.meta
         assert dataset.meta == dict(
             dataset_type=dataset_type, classes=('dog', ))
 
         # test whether self.base_dataset.META is changed when a customize
         # dataset inherit self.base_dataset
         # test reset META in ToyDataset.
-        class ToyDataset(self.base_dataset):
+        class ToyDataset(self.dataset_type):
             META = dict(xxx='xxx')
 
         assert ToyDataset.META == dict(xxx='xxx')
-        assert self.base_dataset.META == dict(
+        assert self.dataset_type.META == dict(
             dataset_type=dataset_type, classes=('dog', 'cat', 'fish'))
 
         # test update META in ToyDataset.
-        class ToyDataset(self.base_dataset):
-            self.base_dataset.META['classes'] = ('bird', )
+        class ToyDataset(self.dataset_type):
+            META = copy.deepcopy(self.dataset_type.META)
+            META['classes'] = ('bird', )
 
         assert ToyDataset.META == dict(
             dataset_type=dataset_type, classes=('bird', ))
-        assert self.base_dataset.META == dict(
+        assert self.dataset_type.META == dict(
             dataset_type=dataset_type, classes=('dog', 'cat', 'fish'))
 
-        # set self.base_dataset to initial state
-        self.__init__()
-
     @pytest.mark.parametrize('lazy_init', [True, False])
     def test_length(self, lazy_init):
-        dataset = self.base_dataset(
+        dataset = self.dataset_type(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img='imgs'),
             ann_file='annotations/dummy_annotation.json',
             lazy_init=lazy_init)
-
         if not lazy_init:
             assert dataset._fully_initialized
             assert hasattr(dataset, 'data_infos')
@@ -175,20 +263,49 @@ class TestBaseDataset:
         else:
             # test `__len__()` when lazy_init is True
             assert not dataset._fully_initialized
-            assert not hasattr(dataset, 'data_infos')
+            assert not dataset.data_infos
             # call `full_init()` automatically
             assert len(dataset) == 2
             assert dataset._fully_initialized
             assert hasattr(dataset, 'data_infos')
 
+    def test_compose(self):
+        # test callable transform
+        transforms = [function_pipeline]
+        compose = Compose(transforms=transforms)
+        assert (self.imgs == compose(dict(img=self.imgs))['img']).all()
+        # test transform build from cfg_dict
+        transforms = [dict(type='CallableTransform')]
+        compose = Compose(transforms=transforms)
+        assert (self.imgs == compose(dict(img=self.imgs))['img']).all()
+        # test return None in advance
+        none_func = MagicMock(return_value=None)
+        transforms = [none_func, function_pipeline]
+        compose = Compose(transforms=transforms)
+        assert compose(dict(img=self.imgs)) is None
+        # test repr
+        repr_str = f'Compose(\n' \
+                   f'    {none_func}\n' \
+                   f'    {function_pipeline}\n' \
+                   f')'
+        assert repr(compose) == repr_str
+        # non-callable transform will raise error
+        with pytest.raises(TypeError):
+            transforms = [dict(type='NotCallableTransform')]
+            Compose(transforms)
+
+        # transform must be callable or dict
+        with pytest.raises(TypeError):
+            Compose([1])
+
     @pytest.mark.parametrize('lazy_init', [True, False])
     def test_getitem(self, lazy_init):
-        dataset = self.base_dataset(
+        dataset = self.dataset_type(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img='imgs'),
             ann_file='annotations/dummy_annotation.json',
             lazy_init=lazy_init)
-
+        dataset.pipeline = self.pipeline
         if not lazy_init:
             assert dataset._fully_initialized
             assert hasattr(dataset, 'data_infos')
@@ -196,15 +313,26 @@ class TestBaseDataset:
         else:
             # test `__getitem__()` when lazy_init is True
             assert not dataset._fully_initialized
-            assert not hasattr(dataset, 'data_infos')
+            assert not dataset.data_infos
             # call `full_init()` automatically
             assert dataset[0] == dict(imgs=self.imgs)
             assert dataset._fully_initialized
             assert hasattr(dataset, 'data_infos')
 
+        # test with test mode
+        dataset.test_mode = True
+        assert dataset[0] == dict(imgs=self.imgs)
+
+        pipeline = MagicMock(return_value=None)
+        dataset.pipeline = pipeline
+        # test the case where a valid image cannot be fetched.
+        dataset.test_mode = False
+        with pytest.raises(Exception):
+            dataset[0]
+
     @pytest.mark.parametrize('lazy_init', [True, False])
     def test_get_data_info(self, lazy_init):
-        dataset = self.base_dataset(
+        dataset = self.dataset_type(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img='imgs'),
             ann_file='annotations/dummy_annotation.json',
@@ -217,20 +345,32 @@ class TestBaseDataset:
         else:
             # test `get_data_info()` when lazy_init is True
             assert not dataset._fully_initialized
-            assert not hasattr(dataset, 'data_infos')
+            assert not dataset.data_infos
             # call `full_init()` automatically
             assert dataset.get_data_info(0) == self.data_info
             assert dataset._fully_initialized
             assert hasattr(dataset, 'data_infos')
 
+    def test_force_full_init(self):
+        with pytest.raises(AttributeError):
+
+            class ClassWithoutFullInit:
+
+                @force_full_init
+                def foo(self):
+                    pass
+
+            class_without_full_init = ClassWithoutFullInit()
+            class_without_full_init.foo()
+
     @pytest.mark.parametrize('lazy_init', [True, False])
     def test_full_init(self, lazy_init):
-        dataset = self.base_dataset(
+        dataset = self.dataset_type(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img='imgs'),
             ann_file='annotations/dummy_annotation.json',
             lazy_init=lazy_init)
-
+        dataset.pipeline = self.pipeline
         if not lazy_init:
             assert dataset._fully_initialized
             assert hasattr(dataset, 'data_infos')
@@ -240,7 +380,7 @@ class TestBaseDataset:
         else:
             # test `full_init()` when lazy_init is True
             assert not dataset._fully_initialized
-            assert not hasattr(dataset, 'data_infos')
+            assert not dataset.data_infos
             # call `full_init()` manually
             dataset.full_init()
             assert dataset._fully_initialized
@@ -249,55 +389,116 @@ class TestBaseDataset:
             assert dataset[0] == dict(imgs=self.imgs)
             assert dataset.get_data_info(0) == self.data_info
 
+    def test_slice_data(self):
+        # passing `num_samples` should slice the dataset to the requested size
+        dataset = self.dataset_type(
+            data_root=osp.join(osp.dirname(__file__), '../data/'),
+            data_prefix=dict(img=None),
+            ann_file='annotations/dummy_annotation.json',
+            num_samples=1)
+        assert len(dataset) == 1
+
+    def test_rand_another(self):
+        # `_rand_another` should return a valid index within the dataset
+        dataset = self.dataset_type(
+            data_root=osp.join(osp.dirname(__file__), '../data/'),
+            data_prefix=dict(img=None),
+            ann_file='annotations/dummy_annotation.json',
+            num_samples=1)
+        assert dataset._rand_another() >= 0
+        assert dataset._rand_another() < len(dataset)
+
 
 class TestConcatDataset:
 
-    def __init__(self):
+    def _init_dataset(self):
         dataset = BaseDataset
 
         # create dataset_a
         data_info = dict(filename='test_img.jpg', height=604, width=640)
         dataset.parse_annotations = MagicMock(return_value=data_info)
         imgs = torch.rand((2, 3, 32, 32))
-        dataset.pipeline = MagicMock(return_value=dict(imgs=imgs))
+
         self.dataset_a = dataset(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img='imgs'),
             ann_file='annotations/dummy_annotation.json')
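+        # mock the pipeline on the instance so that dataset_a and dataset_b
+        # keep independent pipelines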
+        self.dataset_a.pipeline = MagicMock(return_value=dict(imgs=imgs))
 
         # create dataset_b
         data_info = dict(filename='gray.jpg', height=288, width=512)
         dataset.parse_annotations = MagicMock(return_value=data_info)
         imgs = torch.rand((2, 3, 32, 32))
-        dataset.pipeline = MagicMock(return_value=dict(imgs=imgs))
         self.dataset_b = dataset(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img='imgs'),
             ann_file='annotations/dummy_annotation.json',
             meta=dict(classes=('dog', 'cat')))
-
+        self.dataset_b.pipeline = MagicMock(return_value=dict(imgs=imgs))
         # test init
         self.cat_datasets = ConcatDataset(
             datasets=[self.dataset_a, self.dataset_b])
 
+    def test_full_init(self):
+        dataset = BaseDataset
+
+        # create dataset_a
+        data_info = dict(filename='test_img.jpg', height=604, width=640)
+        dataset.parse_annotations = MagicMock(return_value=data_info)
+        imgs = torch.rand((2, 3, 32, 32))
+
+        dataset_a = dataset(
+            data_root=osp.join(osp.dirname(__file__), '../data/'),
+            data_prefix=dict(img='imgs'),
+            ann_file='annotations/dummy_annotation.json')
+        dataset_a.pipeline = MagicMock(return_value=dict(imgs=imgs))
+
+        # create dataset_b
+        data_info = dict(filename='gray.jpg', height=288, width=512)
+        dataset.parse_annotations = MagicMock(return_value=data_info)
+        imgs = torch.rand((2, 3, 32, 32))
+        dataset_b = dataset(
+            data_root=osp.join(osp.dirname(__file__), '../data/'),
+            data_prefix=dict(img='imgs'),
+            ann_file='annotations/dummy_annotation.json',
+            meta=dict(classes=('dog', 'cat')))
+        dataset_b.pipeline = MagicMock(return_value=dict(imgs=imgs))
+        # test init with lazy_init=True
+        cat_datasets = ConcatDataset(
+            datasets=[dataset_a, dataset_b], lazy_init=True)
+        cat_datasets.full_init()
+        assert len(cat_datasets) == 4
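+        # a second `full_init` call and lazy re-initialization via indexing
+        # should leave the length unchanged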
+        cat_datasets.full_init()
+        cat_datasets._fully_initialized = False
+        cat_datasets[1]
+        assert len(cat_datasets) == 4
+
     def test_meta(self):
+        self._init_dataset()
         assert self.cat_datasets.meta == self.dataset_a.meta
         # meta of self.cat_datasets is from the first dataset when
         # concatenating datasets with different metas.
         assert self.cat_datasets.meta != self.dataset_b.meta
 
     def test_length(self):
+        self._init_dataset()
         assert len(self.cat_datasets) == (
             len(self.dataset_a) + len(self.dataset_b))
 
     def test_getitem(self):
-        assert self.cat_datasets[0] == self.dataset_a[0]
-        assert self.cat_datasets[0] != self.dataset_b[0]
+        self._init_dataset()
+        assert (
+            self.cat_datasets[0]['imgs'] == self.dataset_a[0]['imgs']).all()
+        assert (self.cat_datasets[0]['imgs'] !=
+                self.dataset_b[0]['imgs']).all()
 
-        assert self.cat_datasets[-1] == self.dataset_b[-1]
-        assert self.cat_datasets[-1] != self.dataset_a[-1]
+        assert (
+            self.cat_datasets[-1]['imgs'] == self.dataset_b[-1]['imgs']).all()
+        assert (self.cat_datasets[-1]['imgs'] !=
+                self.dataset_a[-1]['imgs']).all()
 
     def test_get_data_info(self):
+        self._init_dataset()
         assert self.cat_datasets.get_data_info(
             0) == self.dataset_a.get_data_info(0)
         assert self.cat_datasets.get_data_info(
@@ -306,40 +507,76 @@ class TestConcatDataset:
         assert self.cat_datasets.get_data_info(
             -1) == self.dataset_b.get_data_info(-1)
         assert self.cat_datasets.get_data_info(
-            -1) != self.dataset_a[-1].get_data_info(-1)
+            -1) != self.dataset_a.get_data_info(-1)
+
+    def test_get_ori_dataset_idx(self):
+        self._init_dataset()
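+        # a flat index should be mapped to (dataset index, sample index
+        # within that dataset)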
+        assert self.cat_datasets._get_ori_dataset_idx(3) == (
+            1, 3 - len(self.dataset_a))
+        assert self.cat_datasets._get_ori_dataset_idx(-1) == (
+            1, len(self.dataset_b) - 1)
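+        # an index beyond the valid negative range should raise ValueError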
+        with pytest.raises(ValueError):
+            self.cat_datasets._get_ori_dataset_idx(-10)
 
 
 class TestRepeatDataset:
 
-    def __init__(self):
+    def _init_dataset(self):
         dataset = BaseDataset
         data_info = dict(filename='test_img.jpg', height=604, width=640)
         dataset.parse_annotations = MagicMock(return_value=data_info)
         imgs = torch.rand((2, 3, 32, 32))
-        dataset.pipeline = MagicMock(return_value=dict(imgs=imgs))
         self.dataset = dataset(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img='imgs'),
             ann_file='annotations/dummy_annotation.json')
+        self.dataset.pipeline = MagicMock(return_value=dict(imgs=imgs))
 
         self.repeat_times = 5
         # test init
         self.repeat_datasets = RepeatDataset(
             dataset=self.dataset, times=self.repeat_times)
 
+    def test_full_init(self):
+        dataset = BaseDataset
+        data_info = dict(filename='test_img.jpg', height=604, width=640)
+        dataset.parse_annotations = MagicMock(return_value=data_info)
+        imgs = torch.rand((2, 3, 32, 32))
+        dataset = dataset(
+            data_root=osp.join(osp.dirname(__file__), '../data/'),
+            data_prefix=dict(img='imgs'),
+            ann_file='annotations/dummy_annotation.json')
+        dataset.pipeline = MagicMock(return_value=dict(imgs=imgs))
+
+        repeat_times = 5
+        # test init
+        repeat_datasets = RepeatDataset(
+            dataset=dataset, times=repeat_times, lazy_init=True)
+
+        repeat_datasets.full_init()
+        assert len(repeat_datasets) == repeat_times * len(dataset)
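+        # a second `full_init` call and lazy re-initialization via indexing
+        # should leave the length unchanged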
+        repeat_datasets.full_init()
+        repeat_datasets._fully_initialized = False
+        repeat_datasets[1]
+        assert len(repeat_datasets) == repeat_times * len(dataset)
+
     def test_meta(self):
+        self._init_dataset()
         assert self.repeat_datasets.meta == self.dataset.meta
 
     def test_length(self):
+        self._init_dataset()
         assert len(
             self.repeat_datasets) == len(self.dataset) * self.repeat_times
 
     def test_getitem(self):
+        self._init_dataset()
         for i in range(self.repeat_times):
             assert self.repeat_datasets[len(self.dataset) *
                                         i] == self.dataset[0]
 
     def test_get_data_info(self):
+        self._init_dataset()
         for i in range(self.repeat_times):
             assert self.repeat_datasets.get_data_info(
                 len(self.dataset) * i) == self.dataset.get_data_info(0)
@@ -347,17 +584,17 @@ class TestRepeatDataset:
 
 class TestClassBalancedDataset:
 
-    def __init__(self):
+    def _init_dataset(self):
         dataset = BaseDataset
         data_info = dict(filename='test_img.jpg', height=604, width=640)
         dataset.parse_annotations = MagicMock(return_value=data_info)
         imgs = torch.rand((2, 3, 32, 32))
-        dataset.pipeline = MagicMock(return_value=dict(imgs=imgs))
         dataset.get_cat_ids = MagicMock(return_value=[0])
         self.dataset = dataset(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img='imgs'),
             ann_file='annotations/dummy_annotation.json')
+        self.dataset.pipeline = MagicMock(return_value=dict(imgs=imgs))
 
         self.repeat_indices = [0, 0, 1, 1, 1]
         # test init
@@ -365,23 +602,54 @@ class TestClassBalancedDataset:
             dataset=self.dataset, oversample_thr=1e-3)
         self.cls_banlanced_datasets.repeat_indices = self.repeat_indices
 
+    def test_full_init(self):
+        dataset = BaseDataset
+        data_info = dict(filename='test_img.jpg', height=604, width=640)
+        dataset.parse_annotations = MagicMock(return_value=data_info)
+        imgs = torch.rand((2, 3, 32, 32))
+        dataset.get_cat_ids = MagicMock(return_value=[0])
+        dataset = dataset(
+            data_root=osp.join(osp.dirname(__file__), '../data/'),
+            data_prefix=dict(img='imgs'),
+            ann_file='annotations/dummy_annotation.json')
+        dataset.pipeline = MagicMock(return_value=dict(imgs=imgs))
+
+        repeat_indices = [0, 0, 1, 1, 1]
+        # test init
+        cls_banlanced_datasets = ClassBalancedDataset(
+            dataset=dataset, oversample_thr=1e-3, lazy_init=True)
+
+        cls_banlanced_datasets.full_init()
+        cls_banlanced_datasets.repeat_indices = repeat_indices
+        assert len(cls_banlanced_datasets) == len(repeat_indices)
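+        # a second `full_init` call and lazy re-initialization via indexing
+        # should keep the repeated length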
+        cls_banlanced_datasets.full_init()
+        cls_banlanced_datasets._fully_initialized = False
+        cls_banlanced_datasets[1]
+        cls_banlanced_datasets.repeat_indices = repeat_indices
+        assert len(cls_banlanced_datasets) == len(repeat_indices)
+
     def test_meta(self):
+        self._init_dataset()
         assert self.cls_banlanced_datasets.meta == self.dataset.meta
 
     def test_length(self):
+        self._init_dataset()
         assert len(self.cls_banlanced_datasets) == len(self.repeat_indices)
 
     def test_getitem(self):
+        self._init_dataset()
         for i in range(len(self.repeat_indices)):
             assert self.cls_banlanced_datasets[i] == self.dataset[
                 self.repeat_indices[i]]
 
     def test_get_data_info(self):
+        self._init_dataset()
         for i in range(len(self.repeat_indices)):
             assert self.cls_banlanced_datasets.get_data_info(
                 i) == self.dataset.get_data_info(self.repeat_indices[i])
 
     def test_get_cat_ids(self):
+        self._init_dataset()
         for i in range(len(self.repeat_indices)):
             assert self.cls_banlanced_datasets.get_cat_ids(
                 i) == self.dataset.get_cat_ids(self.repeat_indices[i])
-- 
GitLab