From ada6660c65be862594d2f3dd0991229b7f619d50 Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Tue, 22 Feb 2022 14:01:06 +0800 Subject: [PATCH] [Feature] add base dataset (#32) * basedataset first commit * add base dataset * add dataset * add basedataset * Fix test dataset * Fix mypy and test * Fix mypy and test * remove unused code * Update mmengine/dataset/base_dataset.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/dataset/base_dataset.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * add more corner cases in unittest * fix lint * Fix as comment * Fix lint * update unitest * Type hint Dick to dict * rename max_refetch * Fix as comment * Fix typo * Fix as comment * BaseDataset is no more an abstrac Class, change UT and docs * Fix as comment * Fix as comment and refactor type error * Add comment for full init * Fix as comment and modify dataset_wrapper * Fix as comment and modify dataset_wrapper * Fix as comment * Fix as comment * Fix as comment * Fix as comment * Fix as comment * Fix as comment * Fix as comment Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: Tao Gong <gongtao950513@gmail.com> --- docs/zh_cn/tutorials/basedataset.md | 6 +- mmengine/__init__.py | 1 + mmengine/dataset/__init__.py | 4 + mmengine/dataset/base_dataset.py | 553 ++++++++++++++++++ mmengine/dataset/dataset_wrapper.py | 364 ++++++++++++ .../annotations/annotation_wrong_format.json | 5 + ...tation.json => annotation_wrong_keys.json} | 0 tests/data/annotations/dummy_annotation.json | 3 +- tests/data/meta/classes.txt | 1 + tests/test_data/test_base_dataset.py | 404 ++++++++++--- 10 files changed, 1268 insertions(+), 73 deletions(-) create mode 100644 mmengine/dataset/__init__.py create mode 100644 mmengine/dataset/base_dataset.py create mode 100644 mmengine/dataset/dataset_wrapper.py create mode 100644 tests/data/annotations/annotation_wrong_format.json rename tests/data/annotations/{wrong_annotation.json => annotation_wrong_keys.json} (100%) create mode 100644 tests/data/meta/classes.txt diff --git a/docs/zh_cn/tutorials/basedataset.md b/docs/zh_cn/tutorials/basedataset.md index 08a49af9..e941a733 100644 --- a/docs/zh_cn/tutorials/basedataset.md +++ b/docs/zh_cn/tutorials/basedataset.md @@ -69,7 +69,7 @@ data 2. 构建数æ®æµæ°´çº¿ï¼ˆdata pipeline),用于数æ®é¢„处ç†ä¸Žæ•°æ®å‡†å¤‡ï¼› -3. 读å–与解æžæ»¡è¶³ OpenMMLab 2.0 æ•°æ®é›†æ ¼å¼è§„èŒƒçš„æ ‡æ³¨æ–‡ä»¶ï¼Œè¯¥æ¥éª¤ä¸ä¼šæœ‰ `parse_annotations()` 抽象方法,该抽象方法负责解æžæ ‡æ³¨æ–‡ä»¶é‡Œçš„æ¯ä¸ªåŽŸå§‹æ•°æ®ï¼› +3. 读å–与解æžæ»¡è¶³ OpenMMLab 2.0 æ•°æ®é›†æ ¼å¼è§„èŒƒçš„æ ‡æ³¨æ–‡ä»¶ï¼Œè¯¥æ¥éª¤ä¸ä¼šæœ‰ `parse_annotations()` 方法,该方法负责解æžæ ‡æ³¨æ–‡ä»¶é‡Œçš„æ¯ä¸ªåŽŸå§‹æ•°æ®ï¼› 4. è¿‡æ»¤æ— ç”¨æ•°æ®ï¼Œæ¯”如ä¸åŒ…å«æ ‡æ³¨çš„æ ·æœ¬ç‰ï¼› @@ -77,8 +77,6 @@ data 6. 
åºåˆ—åŒ–å…¨éƒ¨æ ·æœ¬ï¼Œä»¥è¾¾åˆ°èŠ‚çœå†…å˜çš„效果,详情请å‚考[节çœå†…å˜](#节çœå†…å˜)。 -æ•°æ®é›†åŸºç±»æ˜¯ä¸€ä¸ªæŠ½è±¡ç±»ï¼Œå®ƒæœ‰ä¸”åªæœ‰ä¸€ä¸ªæŠ½è±¡æ–¹æ³• `parse_annotations()` ,`parse_annotations()` å®šä¹‰äº†å°†æ ‡æ³¨æ–‡ä»¶é‡Œçš„ä¸€ä¸ªåŽŸå§‹æ•°æ®å¤„ç†æˆä¸€ä¸ªæˆ–若干个è®ç»ƒ/æµ‹è¯•æ ·æœ¬çš„æ–¹æ³•ã€‚å› æ¤å¯¹äºŽè‡ªå®šä¹‰æ•°æ®é›†ç±»ï¼Œç”¨æˆ·å¿…é¡»è¦å®žçŽ° `parse_annotations()` 方法。 - ### æ•°æ®é›†åŸºç±»æä¾›çš„æŽ¥å£ ä¸Ž `torch.utils.data.Dataset` 类似,数æ®é›†åˆå§‹åŒ–åŽï¼Œæ”¯æŒ `__getitem__` 方法,用æ¥ç´¢å¼•æ•°æ®ï¼Œä»¥åŠ `__len__` æ“作获å–æ•°æ®é›†å¤§å°ï¼Œé™¤æ¤ä¹‹å¤–,OpenMMLab çš„æ•°æ®é›†åŸºç±»ä¸»è¦æ供了以下接å£æ¥è®¿é—®å…·ä½“ä¿¡æ¯ï¼š @@ -93,7 +91,7 @@ data ## 使用数æ®é›†åŸºç±»è‡ªå®šä¹‰æ•°æ®é›†ç±» -在了解了数æ®é›†åŸºç±»çš„åˆå§‹åŒ–æµç¨‹ä¸Žæ供的接å£ä¹‹åŽï¼Œå°±å¯ä»¥åŸºäºŽæ•°æ®é›†åŸºç±»è‡ªå®šä¹‰æ•°æ®é›†ç±»ï¼Œå¦‚上所述,数æ®é›†åŸºç±»æ˜¯ä¸€ä¸ªæŠ½è±¡ç±»ï¼Œå®ƒæœ‰ä¸”åªæœ‰ä¸€ä¸ªæŠ½è±¡æ–¹æ³• `parse_annotations()`ï¼Œå› æ¤ç”¨æˆ·å¿…须在自定义数æ®é›†ç±»ä¸å®žçŽ°è¯¥æ–¹æ³•ã€‚以下是一个使用数æ®é›†åŸºç±»æ¥å®žçŽ°æŸä¸€å…·ä½“æ•°æ®é›†çš„例å。 +在了解了数æ®é›†åŸºç±»çš„åˆå§‹åŒ–æµç¨‹ä¸Žæ供的接å£ä¹‹åŽï¼Œå°±å¯ä»¥åŸºäºŽæ•°æ®é›†åŸºç±»è‡ªå®šä¹‰æ•°æ®é›†ç±»ï¼Œå¦‚上所述,对于满足 OpenMMLab 2.0 æ•°æ®é›†æ ¼å¼è§„èŒƒçš„æ ‡æ³¨æ–‡ä»¶ï¼Œç”¨æˆ·å¯ä»¥é‡è½½ `parse_annotations()`æ¥åŠ è½½æ ‡ç¾ã€‚以下是一个使用数æ®é›†åŸºç±»æ¥å®žçŽ°æŸä¸€å…·ä½“æ•°æ®é›†çš„例å。 ```python import os.path as osp diff --git a/mmengine/__init__.py b/mmengine/__init__.py index 34a3acc1..d389ac84 100644 --- a/mmengine/__init__.py +++ b/mmengine/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. # flake8: noqa from .config import * +from .dataset import * from .fileio import * from .registry import * from .utils import * diff --git a/mmengine/dataset/__init__.py b/mmengine/dataset/__init__.py new file mode 100644 index 00000000..25ee724b --- /dev/null +++ b/mmengine/dataset/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# flake8: noqa +from .base_dataset import BaseDataset, Compose, force_full_init +from .dataset_wrapper import ClassBalancedDataset, ConcatDataset, RepeatDataset diff --git a/mmengine/dataset/base_dataset.py b/mmengine/dataset/base_dataset.py new file mode 100644 index 00000000..586292eb --- /dev/null +++ b/mmengine/dataset/base_dataset.py @@ -0,0 +1,553 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import functools +import gc +import os.path as osp +import pickle +import warnings +from typing import Any, Callable, List, Optional, Sequence, Tuple, Union + +import numpy as np +from torch.utils.data import Dataset + +from mmengine.fileio import list_from_file, load +from mmengine.registry import TRANSFORMS +from mmengine.utils import check_file_exist + + +class Compose: + """Compose multiple transforms sequentially. + + Args: + transforms (Sequence[dict, callable]): Sequence of transform object or + config dict to be composed. + """ + + def __init__(self, transforms: Sequence[Union[dict, Callable]]): + self.transforms: List[Callable] = [] + for transform in transforms: + if isinstance(transform, dict): + transform = TRANSFORMS.build(transform) + if not callable(transform): + raise TypeError(f'transform should be a callable object, ' + f'but got {type(transform)}') + self.transforms.append(transform) + elif callable(transform): + self.transforms.append(transform) + else: + raise TypeError( + f'transform must be a callable object or dict, ' + f'but got {type(transform)}') + + def __call__(self, data: dict) -> Optional[dict]: + """Call function to apply transforms sequentially. + + Args: + data (dict): A result dict contains the data to transform. 
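        Examples:
            A minimal illustration with plain callables standing in for
            real transforms:

            >>> pipeline = Compose([lambda d: dict(d, step1=True),
            ...                     lambda d: dict(d, step2=True)])
            >>> pipeline(dict(img='x'))
            {'img': 'x', 'step1': True, 'step2': True}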
+ + Returns: + dict: Transformed data. + """ + for t in self.transforms: + data = t(data) + if data is None: + return None + return data + + def __repr__(self): + """Print ``self.transforms`` in sequence. + + Returns: + str: Formatted string. + """ + format_string = self.__class__.__name__ + '(' + for t in self.transforms: + format_string += '\n' + format_string += f' {t}' + format_string += '\n)' + return format_string + + +def force_full_init(old_func: Callable) -> Any: + """Those methods decorated by ``force_full_init`` will be forced to call + ``full_init`` if the instance has not been fully initiated. + + Args: + old_func (Callable): Decorated function, make sure the first arg is an + instance with ``full_init`` method. + + Returns: + Any: Depends on old_func. + """ + # TODO This decorator can also be used in runner. + @functools.wraps(old_func) + def wrapper(obj: object, *args, **kwargs): + if not hasattr(obj, 'full_init'): + raise AttributeError(f'{type(obj)} does not have full_init ' + 'method.') + if not getattr(obj, '_fully_initialized', False): + warnings.warn('Attribute `_fully_initialized` is not defined in ' + f'{type(obj)} or `type(obj)._fully_initialized is ' + 'False, `full_init` will be called and ' + f'{type(obj)}._fully_initialized will be set to ' + 'True') + obj.full_init() # type: ignore + obj._fully_initialized = True # type: ignore + + return old_func(obj, *args, **kwargs) + + return wrapper + + +class BaseDataset(Dataset): + r"""BaseDataset for open source projects in OpenMMLab. + + The annotation format is shown as follows. + + .. code-block:: none + + { + "metadata": + { + "dataset_type": "test_dataset", + "task_name": "test_task" + }, + "data_infos": + [ + { + "img_path": "test_img.jpg", + "height": 604, + "width": 640, + "instances": + [ + { + "bbox": [0, 0, 10, 20], + "bbox_label": 1, + "mask": [[0,0],[0,10],[10,20],[20,0]], + "extra_anns": [1,2,3] + }, + { + "bbox": [10, 10, 110, 120], + "bbox_label": 2, + "mask": [[10,10],[10,110],[110,120],[120,10]], + "extra_anns": [4,5,6] + } + ] + }, + ] + } + + Args: + ann_file (str): Annotation file path. + meta (dict, optional): Meta information for dataset, such as class + information. Defaults to None. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Defaults to None. + data_prefix (dict, optional): Prefix for training data. Defaults to + dict(img=None, ann=None). + filter_cfg (dict, optional): Config for filter data. Defaults to None. + num_samples (int, optional): Support using first few data in + annotation file to facilitate training/testing on a smaller + dataset. Defaults to -1 which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. Defaults + to True. + pipeline (list, optional): Processing pipeline. Defaults to []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Defaults to False. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=False``. Defaults to False. + max_refetch (int, optional): The maximum number of cycles to get a + valid image. Defaults to 1000. 
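    Examples:
        A minimal, illustrative instantiation (the paths and the transform
        name below are hypothetical, not part of mmengine):

        >>> dataset = BaseDataset(
        ...     data_root='data/toy/',
        ...     data_prefix=dict(img='imgs'),
        ...     ann_file='annotations/train.json',
        ...     pipeline=[dict(type='LoadImageFromFile')])
        >>> num_samples = len(dataset)             # triggers full_init if lazy
        >>> first_info = dataset.get_data_info(0)  # parsed annotation dict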
+ + Note: + BaseDataset collects meta information from `annotation file` (the + lowest priority), ``BaseDataset.META``(medium) and `meta parameter` + (highest) passed to constructors. The lower priority meta information + will be overwritten by higher one. + """ + + META: dict = dict() + _fully_initialized: bool = False + + def __init__(self, + ann_file: str, + meta: Optional[dict] = None, + data_root: Optional[str] = None, + data_prefix: dict = dict(img=None, ann=None), + filter_cfg: Optional[dict] = None, + num_samples: int = -1, + serialize_data: bool = True, + pipeline: List[Union[dict, Callable]] = [], + test_mode: bool = False, + lazy_init: bool = False, + max_refetch: int = 1000): + + self.data_root = data_root + self.data_prefix = copy.copy(data_prefix) + self.ann_file = ann_file + self.filter_cfg = copy.deepcopy(filter_cfg) + self._num_samples = num_samples + self.serialize_data = serialize_data + self.test_mode = test_mode + self.max_refetch = max_refetch + self.data_infos: List[dict] = [] + self.data_infos_bytes: bytearray = bytearray() + + # set meta information + self._meta = self._get_meta_data(copy.deepcopy(meta)) + + # join paths + if self.data_root is not None: + self._join_prefix() + + # build pipeline + self.pipeline = Compose(pipeline) + + if not lazy_init: + self.full_init() + + @force_full_init + def get_data_info(self, idx: int) -> dict: + """Get annotation by index and automatically call ``full_init`` if the + dataset has not been fully initialized. + + Args: + idx (int): The index of data. + + Returns: + dict: The idx-th annotation of the dataset. + """ + if self.serialize_data: + start_addr = 0 if idx == 0 else self.data_address[idx - 1].item() + end_addr = self.data_address[idx].item() + bytes = memoryview(self.data_infos_bytes[start_addr:end_addr]) + data_info = pickle.loads(bytes) + else: + data_info = self.data_infos[idx] + # To record the real positive index of data information. + if idx >= 0: + data_info['sample_idx'] = idx + else: + data_info['sample_idx'] = len(self) + idx + + return data_info + + def full_init(self): + """Load annotation file and set ``BaseDataset._fully_initialized`` to + True. + + If ``lazy_init=False``, ``full_init`` will be called during the + instantiation and ``self._fully_initialized`` will be set to True. If + ``obj._fully_initialized=False``, the class method decorated by + ``force_full_init`` will call ``full_init`` automatically. + + Several steps to initialize annotation: + + - load_annotations: Load annotations from annotation file. + - filter data information: Filter annotations according to + filter_cfg. + - slice_data: Slice dataset according to ``self._num_samples`` + - serialize_data: Serialize ``self.data_infos`` if + ``self.serialize_data`` is True. + """ + if self._fully_initialized: + return + + # load data information + self.data_infos = self.load_annotations(self.ann_file) + # filter illegal data, such as data that has no annotations. + self.data_infos = self.filter_data() + # if `num_samples > 0`, return the first `num_samples` data information + self.data_infos = self._slice_data() + # serialize data_infos + if self.serialize_data: + self.data_infos_bytes, self.data_address = self._serialize_data() + # Empty cache for preventing making multiple copies of + # `self.data_info` when loading data multi-processes. + self.data_infos.clear() + gc.collect() + self._fully_initialized = True + + @property + def meta(self) -> dict: + """Get meta information of dataset. 
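        Examples:
            An illustrative priority check (paths are hypothetical): the
            ``meta`` argument passed to the constructor overrides both
            ``BaseDataset.META`` and the metadata read from the annotation
            file.

            >>> class ToyDataset(BaseDataset):
            ...     META = dict(classes=('dog', 'cat'))
            >>> dataset = ToyDataset(
            ...     data_root='data/toy/',
            ...     data_prefix=dict(img='imgs'),
            ...     ann_file='annotations/train.json',
            ...     meta=dict(classes=('dog', )))
            >>> dataset.meta['classes']
            ('dog',)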
+ + Returns: + dict: meta information collected from ``BaseDataset.META``, + annotation file and meta parameter during instantiation. + """ + return copy.deepcopy(self._meta) + + def parse_annotations(self, + raw_data_info: dict) -> Union[dict, List[dict]]: + """Parse raw annotation to target format. + + ``parse_annotations`` should return ``dict`` or ``List[dict]``. Each + dict contains the annotations of a training sample. If the protocol of + the sample annotations is changed, this function can be overridden to + update the parsing logic while keeping compatibility. + + Args: + raw_data_info (dict): Raw annotation load from ``ann_file`` + + Returns: + Union[dict, List[dict]]: Parsed annotation. + """ + return raw_data_info + + def filter_data(self) -> List[dict]: + """Filter annotations according to filter_cfg. Defaults return all + ``data_infos``. + + If some ``data_infos`` could be filtered according to specific logic, + the subclass should override this method. + + Returns: + List[dict]: Filtered results. + """ + return self.data_infos + + def get_cat_ids(self, idx: int) -> List[int]: + """Get category ids by index. Dataset wrapped by ClassBalancedDataset + must implement this method. + + The ``ClassBalancedDataset`` requires a subclass which implements this + method. + + Args: + idx (int): The index of data. + + Returns: + List[int]: All categories in the image of specified index. + """ + raise NotImplementedError(f'{type(self)} must implement `get_cat_ids` ' + 'method') + + def __getitem__(self, idx: int) -> dict: + """Get the idx-th image of dataset after ``self.pipelines`` and + ``full_init`` will be called if the dataset has not been fully + initialized. + + During training phase, if ``self.pipelines`` get ``None``, + ``self._rand_another`` will be called until a valid image is fetched or + the maximum limit of refetech is reached. + + Args: + idx (int): The index of self.data_infos + + Returns: + dict: The idx-th image of dataset after ``self.pipelines``. + """ + if not self._fully_initialized: + warnings.warn( + 'Please call `full_init()` method manually to accelerate ' + 'the speed.') + self.full_init() + + if self.test_mode: + return self._prepare_data(idx) + + for _ in range(self.max_refetch): + data_sample = self._prepare_data(idx) + if data_sample is None: + idx = self._rand_another() + continue + return data_sample + + raise Exception(f'Cannot find valid image after {self.max_refetch}! ' + 'Please check your image path and pipelines') + + def load_annotations(self, ann_file: str) -> List[dict]: + """Load annotations from an annotation file. + + If the annotation file does not follow `OpenMMLab 2.0 format dataset + <https://github.com/open-mmlab/mmengine/blob/main/docs/zh_cn/tutorials/basedataset.md>`_ . + The subclass must override this method for load annotations. + + Args: + ann_file (str): Absolute annotation file path if ``self.root=None`` + or relative path if ``self.root=/path/to/data/``. + + Returns: + List[dict]: A list of annotation. 
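        Examples:
            For an annotation file that already follows the 2.0 format, a
            subclass usually keeps this method and only overrides
            ``parse_annotations``; a rough sketch (the field names are
            illustrative):

            >>> import os.path as osp
            >>> class ToyDataset(BaseDataset):
            ...     def parse_annotations(self, raw_data_info):
            ...         data_info = raw_data_info.copy()
            ...         data_info['img_path'] = osp.join(
            ...             self.data_prefix['img'], raw_data_info['img_path'])
            ...         return data_info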
+ """ # noqa: E501 + check_file_exist(ann_file) + annotations = load(ann_file) + if not isinstance(annotations, dict): + raise TypeError(f'The annotations loaded from annotation file ' + f'should be a dict, but got {type(annotations)}!') + if 'data_infos' not in annotations or 'metadata' not in annotations: + raise ValueError('Annotation must have data_infos and metadata ' + 'keys') + meta_data = annotations['metadata'] + raw_data_infos = annotations['data_infos'] + + # update self._meta + for k, v in meta_data.items(): + # We only merge keys that are not contained in self._meta. + self._meta.setdefault(k, v) + + # load and parse data_infos + data_infos = [] + for raw_data_info in raw_data_infos: + data_info = self.parse_annotations(raw_data_info) + if isinstance(data_info, dict): + data_infos.append(data_info) + elif isinstance(data_info, list): + # data_info can also be a list of dict, which means one + # data_info contains multiple samples. + for item in data_info: + if not isinstance(item, dict): + raise TypeError('data_info must be a list[dict], but ' + f'got {type(item)}') + data_infos.extend(data_info) + else: + raise TypeError('data_info should be a dict or list[dict], ' + f'but got {type(data_info)}') + + return data_infos + + @classmethod + def _get_meta_data(cls, in_meta: dict = None) -> dict: + """Collect meta information from the dictionary of meta. + + Args: + in_meta (dict): Meta information dict. If ``in_meta`` contains + existed filename, it will be parsed by ``list_from_file``. + + Returns: + dict: Parsed meta information. + """ + # cls.META will be overwritten by in_meta + cls_meta = copy.deepcopy(cls.META) + if in_meta is None: + return cls_meta + if not isinstance(in_meta, dict): + raise TypeError( + f'in_meta should be a dict, but got {type(in_meta)}') + + for k, v in in_meta.items(): + if isinstance(v, str) and osp.isfile(v): + # if filename in in_meta, this key will be further parsed. + # nested filename will be ignored.: + cls_meta[k] = list_from_file(v) + else: + cls_meta[k] = v + + return cls_meta + + def _join_prefix(self): + """Join ``self.data_root`` with ``self.data_prefix`` and + ``self.ann_file``. + + Examples: + >>> # self.data_prefix contains relative paths + >>> self.data_root = 'a/b/c' + >>> self.data_prefix = dict(img='d/e/') + >>> self.ann_file = 'f' + >>> self._join_prefix() + >>> self.data_prefix + dict(img='a/b/c/d/e') + >>> self.ann_file + 'a/b/c/f' + >>> # self.data_prefix contains absolute paths + >>> self.data_root = 'a/b/c' + >>> self.data_prefix = dict(img='/d/e/') + >>> self.ann_file = 'f' + >>> self._join_prefix() + >>> self.data_prefix + dict(img='/d/e') + >>> self.ann_file + 'a/b/c/f' + """ + if not osp.isabs(self.ann_file): + self.ann_file = osp.join(self.data_root, self.ann_file) + + for data_key, prefix in self.data_prefix.items(): + if prefix is None: + self.data_prefix[data_key] = self.data_root + elif isinstance(prefix, str): + if not osp.isabs(prefix): + self.data_prefix[data_key] = osp.join( + self.data_root, prefix) + else: + raise TypeError('prefix should be a string or None, but got ' + f'{type(prefix)}') + + def _slice_data(self) -> List[dict]: + """Slice ``self.data_infos``. BaseDataset supports only using the first + few data. 
+ + Returns: + List[dict]: A slice of ``self.data_infos`` + """ + assert self._num_samples < len(self.data_infos), \ + f'Slice size({self._num_samples}) is larger than dataset ' \ + f'size({self.data_infos}, please keep `num_sample` smaller than' \ + f'{self.data_infos})' + if self._num_samples > 0: + return self.data_infos[:self._num_samples] + else: + return self.data_infos + + def _serialize_data(self) -> Tuple[np.ndarray, np.ndarray]: + """Serialize ``self.data_infos`` to save memory when launching multiple + workers in data loading. This function will be called in ``full_init``. + + Hold memory using serialized objects, and data loader workers can use + shared RAM from master process instead of making a copy. + + Returns: + Tuple[np.ndarray, np.ndarray]: serialize result and corresponding + address. + """ + + def _serialize(data): + buffer = pickle.dumps(data, protocol=4) + return np.frombuffer(buffer, dtype=np.uint8) + + serialized_data_infos_list = [_serialize(x) for x in self.data_infos] + address_list = np.asarray([len(x) for x in serialized_data_infos_list], + dtype=np.int64) + data_address: np.ndarray = np.cumsum(address_list) + serialized_data_infos = np.concatenate(serialized_data_infos_list) + + return serialized_data_infos, data_address + + def _rand_another(self) -> int: + """Get random index. + + Returns: + int: Random index from 0 to ``len(self)-1`` + """ + return np.random.randint(0, len(self)) + + def _prepare_data(self, idx) -> Any: + """Get data processed by ``self.pipeline``. + + Args: + idx (int): The index of ``data_info``. + + Returns: + Any: Depends on ``self.pipeline``. + """ + data_info = self.get_data_info(idx) + return self.pipeline(data_info) + + @force_full_init + def __len__(self) -> int: + """Get the length of filtered dataset and automatically call + ``full_init`` if the dataset has not been fully init. + + Returns: + int: The length of filtered dataset. + """ + if self.serialize_data: + return len(self.data_address) + else: + return len(self.data_infos) diff --git a/mmengine/dataset/dataset_wrapper.py b/mmengine/dataset/dataset_wrapper.py new file mode 100644 index 00000000..ba8c83cf --- /dev/null +++ b/mmengine/dataset/dataset_wrapper.py @@ -0,0 +1,364 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import bisect +import copy +import math +import warnings +from collections import defaultdict +from typing import List, Sequence, Tuple + +from torch.utils.data.dataset import ConcatDataset as _ConcatDataset + +from .base_dataset import BaseDataset, force_full_init + + +class ConcatDataset(_ConcatDataset): + """A wrapper of concatenated dataset. + + Same as ``torch.utils.data.dataset.ConcatDataset`` and support lazy_init. + + Args: + datasets (Sequence[BaseDataset]): A list of datasets which will be + concatenated. + lazy_init (bool, optional): Whether to load annotation during + instantiation. Defaults to False. + """ + + def __init__(self, + datasets: Sequence[BaseDataset], + lazy_init: bool = False): + # Only use meta of first dataset. + self._meta = datasets[0].meta + self.datasets = datasets # type: ignore + for i, dataset in enumerate(datasets, 1): + if self._meta != dataset.meta: + warnings.warn( + f'The meta information of the {i}-th dataset does not ' + 'match meta information of the first dataset') + + self._fully_initialized = False + if not lazy_init: + self.full_init() + + @property + def meta(self) -> dict: + """Get the meta information of the first dataset in ``self.datasets``. + + Returns: + dict: Meta information of first dataset. 
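        Examples:
            >>> # Illustrative: ``dataset_a`` and ``dataset_b`` are assumed
            >>> # to be already-constructed ``BaseDataset`` instances.
            >>> concat = ConcatDataset(datasets=[dataset_a, dataset_b])
            >>> assert concat.meta == dataset_a.meta
            >>> assert len(concat) == len(dataset_a) + len(dataset_b)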
+ """ + # Prevent `self._meta` from being modified by outside. + return copy.deepcopy(self._meta) + + def full_init(self): + """Loop to ``full_init`` each dataset.""" + if self._fully_initialized: + return + for d in self.datasets: + d.full_init() + # Get the cumulative sizes of `self.datasets`. For example, the length + # of `self.datasets` is [2, 3, 4], the cumulative sizes is [2, 5, 9] + super().__init__(self.datasets) + self._fully_initialized = True + + @force_full_init + def _get_ori_dataset_idx(self, idx: int) -> Tuple[int, int]: + """Convert global idx to local index. + + Args: + idx (int): Global index of ``RepeatDataset``. + + Returns: + Tuple[int, int]: The index of ``self.datasets`` and the local + index of data. + """ + if idx < 0: + if -idx > len(self): + raise ValueError( + f'absolute value of index({idx}) should not exceed dataset' + f'length({len(self)}).') + idx = len(self) + idx + # Get the inner index of single dataset + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + + return dataset_idx, sample_idx + + @force_full_init + def get_data_info(self, idx: int) -> dict: + """Get annotation by index. + + Args: + idx (int): Global index of ``ConcatDataset``. + + Returns: + dict: The idx-th annotation of the datasets. + """ + dataset_idx, sample_idx = self._get_ori_dataset_idx(idx) + return self.datasets[dataset_idx].get_data_info(sample_idx) + + @force_full_init + def __len__(self): + return super().__len__() + + def __getitem__(self, idx): + if not self._fully_initialized: + warnings.warn('Please call `full_init` method manually to ' + 'accelerate the speed.') + self.full_init() + dataset_idx, sample_idx = self._get_ori_dataset_idx(idx) + return self.datasets[dataset_idx][sample_idx] + + +class RepeatDataset: + """A wrapper of repeated dataset. + + The length of repeated dataset will be `times` larger than the original + dataset. This is useful when the data loading time is long but the dataset + is small. Using RepeatDataset can reduce the data loading time between + epochs. + + Args: + dataset (BaseDataset): The dataset to be repeated. + times (int): Repeat times. + lazy_init (bool, optional): Whether to load annotation during + instantiation. Defaults to False. + """ + + def __init__(self, + dataset: BaseDataset, + times: int, + lazy_init: bool = False): + self.dataset = dataset + self.times = times + self._meta = dataset.meta + + self._fully_initialized = False + if not lazy_init: + self.full_init() + + @property + def meta(self) -> dict: + """Get the meta information of the repeated dataset. + + Returns: + dict: The meta information of repeated dataset. + """ + return copy.deepcopy(self._meta) + + def full_init(self): + """Loop to ``full_init`` each dataset.""" + if self._fully_initialized: + return + + self.dataset.full_init() + self._ori_len = len(self.dataset) + self._fully_initialized = True + + @force_full_init + def _get_ori_dataset_idx(self, idx: int) -> int: + """Convert global index to local index. + + Args: + idx: Global index of ``RepeatDataset``. + + Returns: + idx (int): Local index of data. + """ + return idx % self._ori_len + + @force_full_init + def get_data_info(self, idx: int) -> dict: + """Get annotation by index. + + Args: + idx (int): Global index of ``ConcatDataset``. + + Returns: + dict: The idx-th annotation of the datasets. 
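        Examples:
            >>> # Illustrative: global indices wrap around the original
            >>> # dataset, i.e. index ``idx`` maps to
            >>> # ``idx % len(original_dataset)``. ``toy_dataset`` is an
            >>> # assumed ``BaseDataset`` instance with at least two samples.
            >>> repeated = RepeatDataset(dataset=toy_dataset, times=3)
            >>> info = repeated.get_data_info(len(toy_dataset) + 1)
            >>> assert info == toy_dataset.get_data_info(1)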
+ """ + sample_idx = self._get_ori_dataset_idx(idx) + return self.dataset.get_data_info(sample_idx) + + def __getitem__(self, idx): + if not self._fully_initialized: + warnings.warn('Please call `full_init` method manually to ' + 'accelerate the speed.') + self.full_init() + + sample_idx = self._get_ori_dataset_idx(idx) + return self.dataset[sample_idx] + + @force_full_init + def __len__(self): + return self.times * self._ori_len + + +class ClassBalancedDataset: + """A wrapper of class balanced dataset. + + Suitable for training on class imbalanced datasets like LVIS. Following + the sampling strategy in the `paper <https://arxiv.org/abs/1908.03195>`_, + in each epoch, an image may appear multiple times based on its + "repeat factor". + The repeat factor for an image is a function of the frequency the rarest + category labeled in that image. The "frequency of category c" in [0, 1] + is defined by the fraction of images in the training set (without repeats) + in which category c appears. + The dataset needs to instantiate :meth:`get_cat_ids` to support + ClassBalancedDataset. + + The repeat factor is computed as followed. + + 1. For each category c, compute the fraction # of images + that contain it: :math:`f(c)` + 2. For each category c, compute the category-level repeat factor: + :math:`r(c) = max(1, sqrt(t/f(c)))` + 3. For each image I, compute the image-level repeat factor: + :math:`r(I) = max_{c in I} r(c)` + + Args: + dataset (BaseDataset): The dataset to be repeated. + oversample_thr (float): frequency threshold below which data is + repeated. For categories with ``f_c >= oversample_thr``, there is + no oversampling. For categories with ``f_c < oversample_thr``, the + degree of oversampling following the square-root inverse frequency + heuristic above. + lazy_init (bool, optional): whether to load annotation during + instantiation. Defaults to False + """ + + def __init__(self, + dataset: BaseDataset, + oversample_thr: float, + lazy_init: bool = False): + self.dataset = dataset + self.oversample_thr = oversample_thr + self._meta = dataset.meta + + self._fully_initialized = False + if not lazy_init: + self.full_init() + + @property + def meta(self) -> dict: + """Get the meta information of the repeated dataset. + + Returns: + dict: The meta information of repeated dataset. + """ + return copy.deepcopy(self._meta) + + def full_init(self): + """Loop to ``full_init`` each dataset.""" + if self._fully_initialized: + return + + self.dataset.full_init() + + repeat_factors = self._get_repeat_factors(self.dataset, + self.oversample_thr) + repeat_indices = [] + for dataset_index, repeat_factor in enumerate(repeat_factors): + repeat_indices.extend([dataset_index] * math.ceil(repeat_factor)) + self.repeat_indices = repeat_indices + + self._fully_initialized = True + + def _get_repeat_factors(self, dataset: BaseDataset, + repeat_thr: float) -> List[float]: + """Get repeat factor for each images in the dataset. + + Args: + dataset (BaseDataset): The dataset. + repeat_thr (float): The threshold of frequency. If an image + contains the categories whose frequency below the threshold, + it would be repeated. + + Returns: + List[float]: The repeat factors for each images in the dataset. + """ + # 1. 
For each category c, compute the fraction # of images + # that contain it: f(c) + category_freq: defaultdict = defaultdict(float) + num_images = len(dataset) + for idx in range(num_images): + cat_ids = set(self.dataset.get_cat_ids(idx)) + for cat_id in cat_ids: + category_freq[cat_id] += 1 + for k, v in category_freq.items(): + assert v > 0, f'caterogy {k} does not contain any images' + category_freq[k] = v / num_images + + # 2. For each category c, compute the category-level repeat factor: + # r(c) = max(1, sqrt(t/f(c))) + category_repeat = { + cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq)) + for cat_id, cat_freq in category_freq.items() + } + + # 3. For each image I and its labels L(I), compute the image-level + # repeat factor: + # r(I) = max_{c in L(I)} r(c) + repeat_factors = [] + for idx in range(num_images): + cat_ids = set(self.dataset.get_cat_ids(idx)) + repeat_factor = max( + {category_repeat[cat_id] + for cat_id in cat_ids}) + repeat_factors.append(repeat_factor) + + return repeat_factors + + @force_full_init + def _get_ori_dataset_idx(self, idx: int) -> int: + """Convert global index to local index. + + Args: + idx (int): Global index of ``RepeatDataset``. + + Returns: + int: Local index of data. + """ + return self.repeat_indices[idx] + + @force_full_init + def get_cat_ids(self, idx: int) -> List[int]: + """Get category ids of class balanced dataset by index. + + Args: + idx (int): Index of data. + + Returns: + List[int]: All categories in the image of specified index. + """ + sample_idx = self._get_ori_dataset_idx(idx) + return self.dataset.get_cat_ids(sample_idx) + + @force_full_init + def get_data_info(self, idx: int) -> dict: + """Get annotation by index. + + Args: + idx (int): Global index of ``ConcatDataset``. + + Returns: + dict: The idx-th annotation of the dataset. 
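        Examples:
            Worked arithmetic for a single category (the numbers are
            illustrative): with ``oversample_thr=0.5`` and a category that
            appears in 1 of 4 images, ``f(c) = 0.25`` and the category-level
            repeat factor is

            >>> import math
            >>> max(1.0, math.sqrt(0.5 / 0.25))
            1.4142135623730951

            so every image containing only that category is repeated
            ``math.ceil(1.414...) = 2`` times.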
+ """ + sample_idx = self._get_ori_dataset_idx(idx) + return self.dataset.get_data_info(sample_idx) + + def __getitem__(self, idx): + warnings.warn('Please call `full_init` method manually to ' + 'accelerate the speed.') + if not self._fully_initialized: + self.full_init() + + ori_index = self._get_ori_dataset_idx(idx) + return self.dataset[ori_index] + + @force_full_init + def __len__(self): + return len(self.repeat_indices) diff --git a/tests/data/annotations/annotation_wrong_format.json b/tests/data/annotations/annotation_wrong_format.json new file mode 100644 index 00000000..b8dadc0d --- /dev/null +++ b/tests/data/annotations/annotation_wrong_format.json @@ -0,0 +1,5 @@ +[ + { + "img_path": "test_img.jpg" + } +] \ No newline at end of file diff --git a/tests/data/annotations/wrong_annotation.json b/tests/data/annotations/annotation_wrong_keys.json similarity index 100% rename from tests/data/annotations/wrong_annotation.json rename to tests/data/annotations/annotation_wrong_keys.json diff --git a/tests/data/annotations/dummy_annotation.json b/tests/data/annotations/dummy_annotation.json index abba398a..d36109ce 100644 --- a/tests/data/annotations/dummy_annotation.json +++ b/tests/data/annotations/dummy_annotation.json @@ -2,7 +2,8 @@ "metadata": { "dataset_type": "test_dataset", - "task_name": "test_task" + "task_name": "test_task", + "empty_list": [] }, "data_infos": [ diff --git a/tests/data/meta/classes.txt b/tests/data/meta/classes.txt new file mode 100644 index 00000000..18a619c9 --- /dev/null +++ b/tests/data/meta/classes.txt @@ -0,0 +1 @@ +dog diff --git a/tests/test_data/test_base_dataset.py b/tests/test_data/test_base_dataset.py index 2a33427c..e695488c 100644 --- a/tests/test_data/test_base_dataset.py +++ b/tests/test_data/test_base_dataset.py @@ -1,40 +1,66 @@ # Copyright (c) OpenMMLab. All rights reserved. 
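For reference, a minimal annotation dict that satisfies the checks in ``BaseDataset.load_annotations`` (both required top-level keys present) might look like the sketch below; the concrete fields inside each ``data_infos`` item are dataset-specific and mirror the format documented in the ``BaseDataset`` docstring rather than any real dataset.

```python
# Hedged sketch: only ``metadata`` and ``data_infos`` are required keys.
minimal_annotation = {
    'metadata': {'dataset_type': 'test_dataset', 'task_name': 'test_task'},
    'data_infos': [
        {'img_path': 'test_img.jpg', 'height': 604, 'width': 640},
    ],
}
```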
+import copy import os.path as osp from unittest.mock import MagicMock import pytest import torch -from mmengine.data import (BaseDataset, ClassBalancedDataset, ConcatDataset, - RepeatDataset) +from mmengine.dataset import (BaseDataset, ClassBalancedDataset, Compose, + ConcatDataset, RepeatDataset, force_full_init) +from mmengine.registry import TRANSFORMS -class TestBaseDataset: +def function_pipeline(data_info): + return data_info + + +@TRANSFORMS.register_module() +class CallableTransform: + + def __call__(self, data_info): + return data_info - def __init__(self): - self.base_dataset = BaseDataset - self.data_info = dict(filename='test_img.jpg', height=604, width=640) - self.base_dataset.parse_annotations = MagicMock( - return_value=self.data_info) +@TRANSFORMS.register_module() +class NotCallableTransform: + pass - self.imgs = torch.rand((2, 3, 32, 32)) - self.base_dataset.pipeline = MagicMock( - return_value=dict(imgs=self.imgs)) + +class TestBaseDataset: + dataset_type = BaseDataset + data_info = dict( + filename='test_img.jpg', height=604, width=640, sample_idx=0) + imgs = torch.rand((2, 3, 32, 32)) + pipeline = MagicMock(return_value=dict(imgs=imgs)) + META: dict = dict() + parse_annotations = MagicMock(return_value=data_info) + + def _init_dataset(self): + self.dataset_type.META = self.META + self.dataset_type.parse_annotations = self.parse_annotations def test_init(self): + self._init_dataset() # test the instantiation of self.base_dataset - dataset = self.base_dataset( + dataset = self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), ann_file='annotations/dummy_annotation.json') assert dataset._fully_initialized assert hasattr(dataset, 'data_infos') assert hasattr(dataset, 'data_address') + dataset = self.dataset_type( + data_root=osp.join(osp.dirname(__file__), '../data/'), + data_prefix=dict(img=None), + ann_file='annotations/dummy_annotation.json') + assert dataset._fully_initialized + assert hasattr(dataset, 'data_infos') + assert hasattr(dataset, 'data_address') # test the instantiation of self.base_dataset with # `serialize_data=False` - dataset = self.base_dataset( + dataset = self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), ann_file='annotations/dummy_annotation.json', @@ -42,33 +68,54 @@ class TestBaseDataset: assert dataset._fully_initialized assert hasattr(dataset, 'data_infos') assert not hasattr(dataset, 'data_address') + assert len(dataset) == 2 + assert dataset.get_data_info(0) == self.data_info # test the instantiation of self.base_dataset with lazy init - dataset = self.base_dataset( + dataset = self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), ann_file='annotations/dummy_annotation.json', lazy_init=True) assert not dataset._fully_initialized - assert not hasattr(dataset, 'data_infos') + assert not dataset.data_infos + + # test the instantiation of self.base_dataset if ann_file is not + # existed. 
+ with pytest.raises(FileNotFoundError): + self.dataset_type( + data_root=osp.join(osp.dirname(__file__), '../data/'), + data_prefix=dict(img='imgs'), + ann_file='annotations/not_existed_annotation.json') # test the instantiation of self.base_dataset when the ann_file is # wrong with pytest.raises(ValueError): - self.base_dataset( + self.dataset_type( + data_root=osp.join(osp.dirname(__file__), '../data/'), + data_prefix=dict(img='imgs'), + ann_file='annotations/annotation_wrong_keys.json') + with pytest.raises(TypeError): + self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), - ann_file='annotations/wrong_annotation.json') + ann_file='annotations/annotation_wrong_format.json') + with pytest.raises(TypeError): + self.dataset_type( + data_root=osp.join(osp.dirname(__file__), '../data/'), + data_prefix=dict(img=['img']), + ann_file='annotations/annotation_wrong_format.json') # test the instantiation of self.base_dataset when `parse_annotations` # return `list[dict]` - self.base_dataset.parse_annotations = MagicMock( + self.dataset_type.parse_annotations = MagicMock( return_value=[self.data_info, self.data_info.copy()]) - dataset = self.base_dataset( + dataset = self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), ann_file='annotations/dummy_annotation.json') + dataset.pipeline = self.pipeline assert dataset._fully_initialized assert hasattr(dataset, 'data_infos') assert hasattr(dataset, 'data_address') @@ -76,98 +123,139 @@ class TestBaseDataset: assert dataset[0] == dict(imgs=self.imgs) assert dataset.get_data_info(0) == self.data_info - # set self.base_dataset to initial state - self.__init__() + # test the instantiation of self.base_dataset when `parse_annotations` + # return unsupported data. 
+ with pytest.raises(TypeError): + self.dataset_type.parse_annotations = MagicMock(return_value='xxx') + dataset = self.dataset_type( + data_root=osp.join(osp.dirname(__file__), '../data/'), + data_prefix=dict(img='imgs'), + ann_file='annotations/dummy_annotation.json') + with pytest.raises(TypeError): + self.dataset_type.parse_annotations = MagicMock( + return_value=[self.data_info, 'xxx']) + dataset = self.dataset_type( + data_root=osp.join(osp.dirname(__file__), '../data/'), + data_prefix=dict(img='imgs'), + ann_file='annotations/dummy_annotation.json') def test_meta(self): + self._init_dataset() # test dataset.meta with setting the meta from annotation file as the # meta of self.base_dataset - dataset = self.base_dataset( + dataset = self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), ann_file='annotations/dummy_annotation.json') + assert dataset.meta == dict( - dataset_type='test_dataset', task_name='test_task') + dataset_type='test_dataset', task_name='test_task', empty_list=[]) # test dataset.meta with setting META in self.base_dataset dataset_type = 'new_dataset' - self.base_dataset.META = dict( + self.dataset_type.META = dict( dataset_type=dataset_type, classes=('dog', 'cat')) - dataset = self.base_dataset( + dataset = self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), ann_file='annotations/dummy_annotation.json') assert dataset.meta == dict( dataset_type=dataset_type, task_name='test_task', - classes=('dog', 'cat')) + classes=('dog', 'cat'), + empty_list=[]) # test dataset.meta with passing meta into self.base_dataset - meta = dict(classes=('dog', )) - dataset = self.base_dataset( + meta = dict(classes=('dog', ), task_name='new_task') + dataset = self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), ann_file='annotations/dummy_annotation.json', meta=meta) - assert self.base_dataset.META == dict( + assert self.dataset_type.META == dict( dataset_type=dataset_type, classes=('dog', 'cat')) assert dataset.meta == dict( dataset_type=dataset_type, - task_name='test_task', - classes=('dog', )) + task_name='new_task', + classes=('dog', ), + empty_list=[]) # reset `base_dataset.META`, the `dataset.meta` should not change - self.base_dataset.META['classes'] = ('dog', 'cat', 'fish') - assert self.base_dataset.META == dict( + self.dataset_type.META['classes'] = ('dog', 'cat', 'fish') + assert self.dataset_type.META == dict( dataset_type=dataset_type, classes=('dog', 'cat', 'fish')) + assert dataset.meta == dict( + dataset_type=dataset_type, + task_name='new_task', + classes=('dog', ), + empty_list=[]) + + # test dataset.meta with passing meta containing a file into + # self.base_dataset + meta = dict( + classes=osp.join( + osp.dirname(__file__), '../data/meta/classes.txt')) + dataset = self.dataset_type( + data_root=osp.join(osp.dirname(__file__), '../data/'), + data_prefix=dict(img='imgs'), + ann_file='annotations/dummy_annotation.json', + meta=meta) assert dataset.meta == dict( dataset_type=dataset_type, task_name='test_task', - classes=('dog', )) + classes=['dog'], + empty_list=[]) + + # test dataset.meta with passing unsupported meta into + # self.base_dataset + with pytest.raises(TypeError): + meta = 'dog' + dataset = self.dataset_type( + data_root=osp.join(osp.dirname(__file__), '../data/'), + data_prefix=dict(img='imgs'), + ann_file='annotations/dummy_annotation.json', + meta=meta) # test dataset.meta with passing meta into 
self.base_dataset and # lazy_init is True meta = dict(classes=('dog', )) - dataset = self.base_dataset( + dataset = self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), ann_file='annotations/dummy_annotation.json', meta=meta, lazy_init=True) - # 'task_name' not in dataset.meta + # 'task_name' and 'empty_list' not in dataset.meta assert dataset.meta == dict( dataset_type=dataset_type, classes=('dog', )) # test whether self.base_dataset.META is changed when a customize # dataset inherit self.base_dataset # test reset META in ToyDataset. - class ToyDataset(self.base_dataset): + class ToyDataset(self.dataset_type): META = dict(xxx='xxx') assert ToyDataset.META == dict(xxx='xxx') - assert self.base_dataset.META == dict( + assert self.dataset_type.META == dict( dataset_type=dataset_type, classes=('dog', 'cat', 'fish')) # test update META in ToyDataset. - class ToyDataset(self.base_dataset): - self.base_dataset.META['classes'] = ('bird', ) + class ToyDataset(self.dataset_type): + META = copy.deepcopy(self.dataset_type.META) + META['classes'] = ('bird', ) assert ToyDataset.META == dict( dataset_type=dataset_type, classes=('bird', )) - assert self.base_dataset.META == dict( + assert self.dataset_type.META == dict( dataset_type=dataset_type, classes=('dog', 'cat', 'fish')) - # set self.base_dataset to initial state - self.__init__() - @pytest.mark.parametrize('lazy_init', [True, False]) def test_length(self, lazy_init): - dataset = self.base_dataset( + dataset = self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), ann_file='annotations/dummy_annotation.json', lazy_init=lazy_init) - if not lazy_init: assert dataset._fully_initialized assert hasattr(dataset, 'data_infos') @@ -175,20 +263,49 @@ class TestBaseDataset: else: # test `__len__()` when lazy_init is True assert not dataset._fully_initialized - assert not hasattr(dataset, 'data_infos') + assert not dataset.data_infos # call `full_init()` automatically assert len(dataset) == 2 assert dataset._fully_initialized assert hasattr(dataset, 'data_infos') + def test_compose(self): + # test callable transform + transforms = [function_pipeline] + compose = Compose(transforms=transforms) + assert (self.imgs == compose(dict(img=self.imgs))['img']).all() + # test transform build from cfg_dict + transforms = [dict(type='CallableTransform')] + compose = Compose(transforms=transforms) + assert (self.imgs == compose(dict(img=self.imgs))['img']).all() + # test return None in advance + none_func = MagicMock(return_value=None) + transforms = [none_func, function_pipeline] + compose = Compose(transforms=transforms) + assert compose(dict(img=self.imgs)) is None + # test repr + repr_str = f'Compose(\n' \ + f' {none_func}\n' \ + f' {function_pipeline}\n' \ + f')' + assert repr(compose) == repr_str + # non-callable transform will raise error + with pytest.raises(TypeError): + transforms = [dict(type='NotCallableTransform')] + Compose(transforms) + + # transform must be callable or dict + with pytest.raises(TypeError): + Compose([1]) + @pytest.mark.parametrize('lazy_init', [True, False]) def test_getitem(self, lazy_init): - dataset = self.base_dataset( + dataset = self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), ann_file='annotations/dummy_annotation.json', lazy_init=lazy_init) - + dataset.pipeline = self.pipeline if not lazy_init: assert dataset._fully_initialized assert hasattr(dataset, 'data_infos') @@ -196,15 
+313,26 @@ class TestBaseDataset: else: # test `__getitem__()` when lazy_init is True assert not dataset._fully_initialized - assert not hasattr(dataset, 'data_infos') + assert not dataset.data_infos # call `full_init()` automatically assert dataset[0] == dict(imgs=self.imgs) assert dataset._fully_initialized assert hasattr(dataset, 'data_infos') + # test with test mode + dataset.test_mode = True + assert dataset[0] == dict(imgs=self.imgs) + + pipeline = MagicMock(return_value=None) + dataset.pipeline = pipeline + # test cannot get a valid image. + dataset.test_mode = False + with pytest.raises(Exception): + dataset[0] + @pytest.mark.parametrize('lazy_init', [True, False]) def test_get_data_info(self, lazy_init): - dataset = self.base_dataset( + dataset = self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), ann_file='annotations/dummy_annotation.json', @@ -217,20 +345,32 @@ class TestBaseDataset: else: # test `get_data_info()` when lazy_init is True assert not dataset._fully_initialized - assert not hasattr(dataset, 'data_infos') + assert not dataset.data_infos # call `full_init()` automatically assert dataset.get_data_info(0) == self.data_info assert dataset._fully_initialized assert hasattr(dataset, 'data_infos') + def test_force_full_init(self): + with pytest.raises(AttributeError): + + class ClassWithoutFullInit: + + @force_full_init + def foo(self): + pass + + class_without_full_init = ClassWithoutFullInit() + class_without_full_init.foo() + @pytest.mark.parametrize('lazy_init', [True, False]) def test_full_init(self, lazy_init): - dataset = self.base_dataset( + dataset = self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), ann_file='annotations/dummy_annotation.json', lazy_init=lazy_init) - + dataset.pipeline = self.pipeline if not lazy_init: assert dataset._fully_initialized assert hasattr(dataset, 'data_infos') @@ -240,7 +380,7 @@ class TestBaseDataset: else: # test `full_init()` when lazy_init is True assert not dataset._fully_initialized - assert not hasattr(dataset, 'data_infos') + assert not dataset.data_infos # call `full_init()` manually dataset.full_init() assert dataset._fully_initialized @@ -249,55 +389,116 @@ class TestBaseDataset: assert dataset[0] == dict(imgs=self.imgs) assert dataset.get_data_info(0) == self.data_info + def test_slice_data(self): + # test the instantiation of self.base_dataset when passing num_samples + dataset = self.dataset_type( + data_root=osp.join(osp.dirname(__file__), '../data/'), + data_prefix=dict(img=None), + ann_file='annotations/dummy_annotation.json', + num_samples=1) + assert len(dataset) == 1 + + def test_rand_another(self): + # test the instantiation of self.base_dataset when passing num_samples + dataset = self.dataset_type( + data_root=osp.join(osp.dirname(__file__), '../data/'), + data_prefix=dict(img=None), + ann_file='annotations/dummy_annotation.json', + num_samples=1) + assert dataset._rand_another() >= 0 + assert dataset._rand_another() < len(dataset) + class TestConcatDataset: - def __init__(self): + def _init_dataset(self): dataset = BaseDataset # create dataset_a data_info = dict(filename='test_img.jpg', height=604, width=640) dataset.parse_annotations = MagicMock(return_value=data_info) imgs = torch.rand((2, 3, 32, 32)) - dataset.pipeline = MagicMock(return_value=dict(imgs=imgs)) + self.dataset_a = dataset( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), 
ann_file='annotations/dummy_annotation.json') + self.dataset_a.pipeline = MagicMock(return_value=dict(imgs=imgs)) # create dataset_b data_info = dict(filename='gray.jpg', height=288, width=512) dataset.parse_annotations = MagicMock(return_value=data_info) imgs = torch.rand((2, 3, 32, 32)) - dataset.pipeline = MagicMock(return_value=dict(imgs=imgs)) self.dataset_b = dataset( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), ann_file='annotations/dummy_annotation.json', meta=dict(classes=('dog', 'cat'))) - + self.dataset_b.pipeline = MagicMock(return_value=dict(imgs=imgs)) # test init self.cat_datasets = ConcatDataset( datasets=[self.dataset_a, self.dataset_b]) + def test_full_init(self): + dataset = BaseDataset + + # create dataset_a + data_info = dict(filename='test_img.jpg', height=604, width=640) + dataset.parse_annotations = MagicMock(return_value=data_info) + imgs = torch.rand((2, 3, 32, 32)) + + dataset_a = dataset( + data_root=osp.join(osp.dirname(__file__), '../data/'), + data_prefix=dict(img='imgs'), + ann_file='annotations/dummy_annotation.json') + dataset_a.pipeline = MagicMock(return_value=dict(imgs=imgs)) + + # create dataset_b + data_info = dict(filename='gray.jpg', height=288, width=512) + dataset.parse_annotations = MagicMock(return_value=data_info) + imgs = torch.rand((2, 3, 32, 32)) + dataset_b = dataset( + data_root=osp.join(osp.dirname(__file__), '../data/'), + data_prefix=dict(img='imgs'), + ann_file='annotations/dummy_annotation.json', + meta=dict(classes=('dog', 'cat'))) + dataset_b.pipeline = MagicMock(return_value=dict(imgs=imgs)) + # test init with lazy_init=True + cat_datasets = ConcatDataset( + datasets=[dataset_a, dataset_b], lazy_init=True) + cat_datasets.full_init() + assert len(cat_datasets) == 4 + cat_datasets.full_init() + cat_datasets._fully_initialized = False + cat_datasets[1] + assert len(cat_datasets) == 4 + def test_meta(self): + self._init_dataset() assert self.cat_datasets.meta == self.dataset_a.meta # meta of self.cat_datasets is from the first dataset when # concatnating datasets with different metas. 
assert self.cat_datasets.meta != self.dataset_b.meta def test_length(self): + self._init_dataset() assert len(self.cat_datasets) == ( len(self.dataset_a) + len(self.dataset_b)) def test_getitem(self): - assert self.cat_datasets[0] == self.dataset_a[0] - assert self.cat_datasets[0] != self.dataset_b[0] + self._init_dataset() + assert ( + self.cat_datasets[0]['imgs'] == self.dataset_a[0]['imgs']).all() + assert (self.cat_datasets[0]['imgs'] != + self.dataset_b[0]['imgs']).all() - assert self.cat_datasets[-1] == self.dataset_b[-1] - assert self.cat_datasets[-1] != self.dataset_a[-1] + assert ( + self.cat_datasets[-1]['imgs'] == self.dataset_b[-1]['imgs']).all() + assert (self.cat_datasets[-1]['imgs'] != + self.dataset_a[-1]['imgs']).all() def test_get_data_info(self): + self._init_dataset() assert self.cat_datasets.get_data_info( 0) == self.dataset_a.get_data_info(0) assert self.cat_datasets.get_data_info( @@ -306,40 +507,76 @@ class TestConcatDataset: assert self.cat_datasets.get_data_info( -1) == self.dataset_b.get_data_info(-1) assert self.cat_datasets.get_data_info( - -1) != self.dataset_a[-1].get_data_info(-1) + -1) != self.dataset_a.get_data_info(-1) + + def test_get_ori_dataset_idx(self): + self._init_dataset() + assert self.cat_datasets._get_ori_dataset_idx(3) == ( + 1, 3 - len(self.dataset_a)) + assert self.cat_datasets._get_ori_dataset_idx(-1) == ( + 1, len(self.dataset_b) - 1) + with pytest.raises(ValueError): + assert self.cat_datasets._get_ori_dataset_idx(-10) class TestRepeatDataset: - def __init__(self): + def _init_dataset(self): dataset = BaseDataset data_info = dict(filename='test_img.jpg', height=604, width=640) dataset.parse_annotations = MagicMock(return_value=data_info) imgs = torch.rand((2, 3, 32, 32)) - dataset.pipeline = MagicMock(return_value=dict(imgs=imgs)) self.dataset = dataset( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), ann_file='annotations/dummy_annotation.json') + self.dataset.pipeline = MagicMock(return_value=dict(imgs=imgs)) self.repeat_times = 5 # test init self.repeat_datasets = RepeatDataset( dataset=self.dataset, times=self.repeat_times) + def test_full_init(self): + dataset = BaseDataset + data_info = dict(filename='test_img.jpg', height=604, width=640) + dataset.parse_annotations = MagicMock(return_value=data_info) + imgs = torch.rand((2, 3, 32, 32)) + dataset = dataset( + data_root=osp.join(osp.dirname(__file__), '../data/'), + data_prefix=dict(img='imgs'), + ann_file='annotations/dummy_annotation.json') + dataset.pipeline = MagicMock(return_value=dict(imgs=imgs)) + + repeat_times = 5 + # test init + repeat_datasets = RepeatDataset( + dataset=dataset, times=repeat_times, lazy_init=True) + + repeat_datasets.full_init() + assert len(repeat_datasets) == repeat_times * len(dataset) + repeat_datasets.full_init() + repeat_datasets._fully_initialized = False + repeat_datasets[1] + assert len(repeat_datasets) == repeat_times * len(dataset) + def test_meta(self): + self._init_dataset() assert self.repeat_datasets.meta == self.dataset.meta def test_length(self): + self._init_dataset() assert len( self.repeat_datasets) == len(self.dataset) * self.repeat_times def test_getitem(self): + self._init_dataset() for i in range(self.repeat_times): assert self.repeat_datasets[len(self.dataset) * i] == self.dataset[0] def test_get_data_info(self): + self._init_dataset() for i in range(self.repeat_times): assert self.repeat_datasets.get_data_info( len(self.dataset) * i) == self.dataset.get_data_info(0) @@ -347,17 +584,17 @@ class 
TestRepeatDataset: class TestClassBalancedDataset: - def __init__(self): + def _init_dataset(self): dataset = BaseDataset data_info = dict(filename='test_img.jpg', height=604, width=640) dataset.parse_annotations = MagicMock(return_value=data_info) imgs = torch.rand((2, 3, 32, 32)) - dataset.pipeline = MagicMock(return_value=dict(imgs=imgs)) dataset.get_cat_ids = MagicMock(return_value=[0]) self.dataset = dataset( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), ann_file='annotations/dummy_annotation.json') + self.dataset.pipeline = MagicMock(return_value=dict(imgs=imgs)) self.repeat_indices = [0, 0, 1, 1, 1] # test init @@ -365,23 +602,54 @@ class TestClassBalancedDataset: dataset=self.dataset, oversample_thr=1e-3) self.cls_banlanced_datasets.repeat_indices = self.repeat_indices + def test_full_init(self): + dataset = BaseDataset + data_info = dict(filename='test_img.jpg', height=604, width=640) + dataset.parse_annotations = MagicMock(return_value=data_info) + imgs = torch.rand((2, 3, 32, 32)) + dataset.get_cat_ids = MagicMock(return_value=[0]) + dataset = dataset( + data_root=osp.join(osp.dirname(__file__), '../data/'), + data_prefix=dict(img='imgs'), + ann_file='annotations/dummy_annotation.json') + dataset.pipeline = MagicMock(return_value=dict(imgs=imgs)) + + repeat_indices = [0, 0, 1, 1, 1] + # test init + cls_banlanced_datasets = ClassBalancedDataset( + dataset=dataset, oversample_thr=1e-3, lazy_init=True) + + cls_banlanced_datasets.full_init() + cls_banlanced_datasets.repeat_indices = repeat_indices + assert len(cls_banlanced_datasets) == len(repeat_indices) + cls_banlanced_datasets.full_init() + cls_banlanced_datasets._fully_initialized = False + cls_banlanced_datasets[1] + cls_banlanced_datasets.repeat_indices = repeat_indices + assert len(cls_banlanced_datasets) == len(repeat_indices) + def test_meta(self): + self._init_dataset() assert self.cls_banlanced_datasets.meta == self.dataset.meta def test_length(self): + self._init_dataset() assert len(self.cls_banlanced_datasets) == len(self.repeat_indices) def test_getitem(self): + self._init_dataset() for i in range(len(self.repeat_indices)): assert self.cls_banlanced_datasets[i] == self.dataset[ self.repeat_indices[i]] def test_get_data_info(self): + self._init_dataset() for i in range(len(self.repeat_indices)): assert self.cls_banlanced_datasets.get_data_info( i) == self.dataset.get_data_info(self.repeat_indices[i]) def test_get_cat_ids(self): + self._init_dataset() for i in range(len(self.repeat_indices)): assert self.cls_banlanced_datasets.get_cat_ids( i) == self.dataset.get_cat_ids(self.repeat_indices[i]) -- GitLab
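A hedged sketch of how a downstream repository might build a concrete dataset on top of ``BaseDataset``; the class name, paths and annotation fields below are hypothetical and only follow the format described in the ``BaseDataset`` docstring:

```python
import os.path as osp

from mmengine.dataset import BaseDataset


class ToyDataset(BaseDataset):
    """Illustrative subclass: only ``parse_annotations`` is overridden."""

    META = dict(classes=('dog', 'cat'))

    def parse_annotations(self, raw_data_info):
        data_info = raw_data_info.copy()
        # Turn the relative image path from the annotation file into a
        # full path using the prefix joined in ``_join_prefix``.
        data_info['img_path'] = osp.join(self.data_prefix['img'],
                                         raw_data_info['img_path'])
        return data_info


# Lazy construction: the annotation file is only read once ``full_init``
# runs, either explicitly as below or implicitly through ``len()``,
# ``__getitem__`` or ``get_data_info`` via the ``force_full_init`` decorator.
dataset = ToyDataset(
    data_root='data/toy/',
    data_prefix=dict(img='imgs'),
    ann_file='annotations/train.json',
    pipeline=[],
    lazy_init=True)
dataset.full_init()
```

The wrappers added in ``dataset_wrapper.py`` take such a dataset as input, e.g. ``RepeatDataset(dataset, times=2)`` or ``ClassBalancedDataset(dataset, oversample_thr=1e-3)``; the latter additionally requires the dataset to implement ``get_cat_ids``.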