diff --git a/mmengine/dataset/base_dataset.py b/mmengine/dataset/base_dataset.py index 0fdfbd0f95c4629dcb46a02889b95a482f979212..179d003836a7a9c7a03f653106a66ca5452cddf5 100644 --- a/mmengine/dataset/base_dataset.py +++ b/mmengine/dataset/base_dataset.py @@ -25,7 +25,10 @@ class Compose: def __init__(self, transforms: Sequence[Union[dict, Callable]]): self.transforms: List[Callable] = [] + for transform in transforms: + # `Compose` can be built with config dict with type and + # corresponding arguments. if isinstance(transform, dict): transform = TRANSFORMS.build(transform) if not callable(transform): @@ -50,6 +53,10 @@ class Compose: """ for t in self.transforms: data = t(data) + # The transform will return None when it failed to load images or + # cannot find suitable augmentation parameters to augment the data. + # Here we simply return None if the transform returns None and the + # dataset will handle it by randomly selecting another data sample. if data is None: return None return data @@ -79,12 +86,16 @@ def force_full_init(old_func: Callable) -> Any: Returns: Any: Depends on old_func. """ - # TODO This decorator can also be used in runner. + @functools.wraps(old_func) def wrapper(obj: object, *args, **kwargs): + # The instance must have `full_init` method. if not hasattr(obj, 'full_init'): raise AttributeError(f'{type(obj)} does not have full_init ' 'method.') + # If instance does not have `_fully_initialized` attribute or + # `_fully_initialized` is False, call `full_init` and set + # `_fully_initialized` to True if not getattr(obj, '_fully_initialized', False): warnings.warn('Attribute `_fully_initialized` is not defined in ' f'{type(obj)} or `type(obj)._fully_initialized is ' @@ -139,16 +150,16 @@ class BaseDataset(Dataset): Args: ann_file (str): Annotation file path. - meta (dict, optional): Meta information for dataset, such as class + metainfo (dict, optional): Meta information for dataset, such as class information. Defaults to None. data_root (str, optional): The root directory for ``data_prefix`` and ``ann_file``. Defaults to None. data_prefix (dict, optional): Prefix for training data. Defaults to dict(img=None, ann=None). filter_cfg (dict, optional): Config for filter data. Defaults to None. - num_samples (int, optional): Support using first few data in - annotation file to facilitate training/testing on a smaller - dataset. Defaults to -1 which means using all ``data_infos``. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Defaults to None which means using all ``data_infos``. serialize_data (bool, optional): Whether to hold memory using serialized objects, when enabled, data loader workers can use shared RAM from master process instead of making a copy. Defaults @@ -161,26 +172,48 @@ class BaseDataset(Dataset): information of the dataset is needed, which is not necessary to load annotation file. ``Basedataset`` can skip load annotations to save time by set ``lazy_init=False``. Defaults to False. - max_refetch (int, optional): The maximum number of cycles to get a - valid image. Defaults to 1000. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Defaults to 1000. Note: BaseDataset collects meta information from `annotation file` (the - lowest priority), ``BaseDataset.META``(medium) and `meta parameter` - (highest) passed to constructors. 
The lower priority meta information - will be overwritten by higher one. + lowest priority), ``BaseDataset.METAINFO``(medium) and `metainfo + parameter` (highest) passed to constructors. The lower priority meta + information will be overwritten by higher one. + + Note: + Dataset wrappers such as ``ConcatDataset``, ``RepeatDataset``, etc. + should not inherit from ``BaseDataset`` since ``get_subset`` and + ``get_subset_`` could produce a sub-dataset with ambiguous meaning, + which conflicts with the original dataset. + + Examples: + Assume the annotation file is given above. + >>> class CustomDataset(BaseDataset): + >>> METAINFO: dict = dict(task_name='custom_task', + >>> dataset_type='custom_type') + >>> metainfo=dict(task_name='custom_task_name') + >>> custom_dataset = CustomDataset( + >>> 'path/to/ann_file', + >>> metainfo=metainfo) + >>> # meta information of annotation file will be overwritten by + >>> # `CustomDataset.METAINFO`. The merged meta information will + >>> # further be overwritten by argument `metainfo`. + >>> custom_dataset.metainfo + {'task_name': custom_task_name, dataset_type: custom_type} """ - META: dict = dict() + METAINFO: dict = dict() _fully_initialized: bool = False def __init__(self, ann_file: str, - meta: Optional[dict] = None, + metainfo: Optional[dict] = None, data_root: Optional[str] = None, data_prefix: dict = dict(img=None, ann=None), filter_cfg: Optional[dict] = None, - num_samples: int = -1, + indices: Optional[Union[int, Sequence[int]]] = None, serialize_data: bool = True, pipeline: List[Union[dict, Callable]] = [], test_mode: bool = False, @@ -191,23 +224,23 @@ class BaseDataset(Dataset): self.data_prefix = copy.copy(data_prefix) self.ann_file = ann_file self.filter_cfg = copy.deepcopy(filter_cfg) - self._num_samples = num_samples + self._indices = indices self.serialize_data = serialize_data self.test_mode = test_mode self.max_refetch = max_refetch - self.data_infos: List[dict] = [] - self.data_infos_bytes: bytearray = bytearray() + self.data_list: List[dict] = [] + self.date_bytes: np.ndarray - # set meta information - self._meta = self._get_meta_data(copy.deepcopy(meta)) + # Set meta information. + self._metainfo = self._get_meta_info(copy.deepcopy(metainfo)) - # join paths + # Join paths. if self.data_root is not None: self._join_prefix() - # build pipeline + # Build pipeline. self.pipeline = Compose(pipeline) - + # Fully initialize the dataset. if not lazy_init: self.full_init() @@ -225,11 +258,13 @@ class BaseDataset(Dataset): if self.serialize_data: start_addr = 0 if idx == 0 else self.data_address[idx - 1].item() end_addr = self.data_address[idx].item() - bytes = memoryview(self.data_infos_bytes[start_addr:end_addr]) - data_info = pickle.loads(bytes) + bytes = memoryview( + self.date_bytes[start_addr:end_addr]) # type: ignore + data_info = pickle.loads(bytes) # type: ignore else: - data_info = self.data_infos[idx] - # To record the real positive index of data information. + data_info = self.data_list[idx] + # Some codebases need `sample_idx` of data information. Here we convert + # the idx to a positive number and save it in data information. if idx >= 0: data_info['sample_idx'] = idx else: @@ -248,69 +283,66 @@ class BaseDataset(Dataset): Several steps to initialize annotation: - - load_annotations: Load annotations from annotation file. + - load_data_list: Load annotations from annotation file. - filter data information: Filter annotations according to - filter_cfg.
- - slice_data: Slice dataset according to ``self._num_samples`` - - serialize_data: Serialize ``self.data_infos`` if + filter_cfg. + - slice_data: Slice dataset according to ``self._indices`` + - serialize_data: Serialize ``self.data_list`` if ``self.serialize_data`` is True. """ if self._fully_initialized: return - # load data information - self.data_infos = self.load_annotations(self.ann_file) + self.data_list = self.load_data_list(self.ann_file) # filter illegal data, such as data that has no annotations. - self.data_infos = self.filter_data() - # if `num_samples > 0`, return the first `num_samples` data information - self.data_infos = self._slice_data() - # serialize data_infos + self.data_list = self.filter_data() + # Get subset data according to indices. + if self._indices is not None: + self.data_list = self._get_unserialized_subset(self._indices) + + # serialize data_list if self.serialize_data: - self.data_infos_bytes, self.data_address = self._serialize_data() - # Empty cache for preventing making multiple copies of - # `self.data_info` when loading data multi-processes. - self.data_infos.clear() - gc.collect() + self.date_bytes, self.data_address = self._serialize_data() + self._fully_initialized = True @property - def meta(self) -> dict: + def metainfo(self) -> dict: """Get meta information of dataset. Returns: - dict: meta information collected from ``BaseDataset.META``, - annotation file and meta parameter during instantiation. + dict: meta information collected from ``BaseDataset.METAINFO``, + annotation file and metainfo argument during instantiation. """ - return copy.deepcopy(self._meta) + return copy.deepcopy(self._metainfo) - def parse_annotations(self, - raw_data_info: dict) -> Union[dict, List[dict]]: + def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]: """Parse raw annotation to target format. - ``parse_annotations`` should return ``dict`` or ``List[dict]``. Each - dict contains the annotations of a training sample. If the protocol of + This method should return dict or list of dict. Each dict or list + contains the data information of a training sample. If the protocol of the sample annotations is changed, this function can be overridden to update the parsing logic while keeping compatibility. Args: - raw_data_info (dict): Raw annotation load from ``ann_file`` + raw_data_info (dict): Raw data information loaded from ``ann_file``. Returns: - Union[dict, List[dict]]: Parsed annotation. + dict or list[dict]: Parsed annotation. """ return raw_data_info def filter_data(self) -> List[dict]: """Filter annotations according to filter_cfg. Defaults return all - ``data_infos``. + ``data_list``. - If some ``data_infos`` could be filtered according to specific logic, + If some ``data_list`` could be filtered according to specific logic, the subclass should override this method. Returns: - List[dict]: Filtered results. + list[dict]: Filtered results. """ - return self.data_infos + return self.data_list def get_cat_ids(self, idx: int) -> List[int]: """Get category ids by index. Dataset wrapped by ClassBalancedDataset @@ -323,26 +355,34 @@ class BaseDataset(Dataset): idx (int): The index of data. Returns: - List[int]: All categories in the image of specified index. + list[int]: All categories in the image of specified index.
""" raise NotImplementedError(f'{type(self)} must implement `get_cat_ids` ' 'method') def __getitem__(self, idx: int) -> dict: - """Get the idx-th image of dataset after ``self.pipelines`` and - ``full_init`` will be called if the dataset has not been fully - initialized. + """Get the idx-th image and data information of dataset after + ``self.pipeline``, and ``full_init`` will be called if the dataset has + not been fully initialized. - During training phase, if ``self.pipelines`` get ``None``, + During training phase, if ``self.pipeline`` get ``None``, ``self._rand_another`` will be called until a valid image is fetched or the maximum limit of refetech is reached. Args: - idx (int): The index of self.data_infos + idx (int): The index of self.data_list. Returns: - dict: The idx-th image of dataset after ``self.pipelines``. + dict: The idx-th image and data information of dataset after + ``self.pipeline``. """ + # Performing full initialization by calling `__getitem__` will consume + # extra memory. If a dataset is not fully initialized by setting + # `lazy_init=True` and then fed into the dataloader. Different workers + # will simultaneously read and parse the annotation. It will cost more + # time and memory, although this may work. Therefore, it is recommended + # to manually call `full_init` before dataset fed into dataloader to + # ensure all workers use shared RAM from master process. if not self._fully_initialized: warnings.warn( 'Please call `full_init()` method manually to accelerate ' @@ -350,31 +390,39 @@ class BaseDataset(Dataset): self.full_init() if self.test_mode: - return self.prepare_data(idx) - - for _ in range(self.max_refetch): - data_sample = self.prepare_data(idx) - if data_sample is None: + data = self.prepare_data(idx) + if data is None: + raise Exception('Test time pipline should not get `None` ' + 'data_sample') + return data + + for _ in range(self.max_refetch + 1): + data = self.prepare_data(idx) + # Broken images or random augmentations may cause the returned data + # to be None + if data is None: idx = self._rand_another() continue - return data_sample + return data raise Exception(f'Cannot find valid image after {self.max_refetch}! ' - 'Please check your image path and pipelines') + 'Please check your image path and pipeline') - def load_annotations(self, ann_file: str) -> List[dict]: + def load_data_list(self, ann_file: str) -> List[dict]: """Load annotations from an annotation file. If the annotation file does not follow `OpenMMLab 2.0 format dataset <https://github.com/open-mmlab/mmengine/blob/main/docs/zh_cn/tutorials/basedataset.md>`_ . - The subclass must override this method for load annotations. + The subclass must override this method for load annotations. The meta + information of annotation file will be overwritten :attr:`METAINFO` + and ``metainfo`` argument of constructor. Args: ann_file (str): Absolute annotation file path if ``self.root=None`` or relative path if ``self.root=/path/to/data/``. Returns: - List[dict]: A list of annotation. + list[dict]: A list of annotation. """ # noqa: E501 check_file_exist(ann_file) annotations = load(ann_file) @@ -387,59 +435,66 @@ class BaseDataset(Dataset): meta_data = annotations['metadata'] raw_data_infos = annotations['data_infos'] - # update self._meta + # Meta information load from annotation file will not influence the + # existed meta information load from `BaseDataset.METAINFO` and + # `metainfo` arguments defined in constructor. 
for k, v in meta_data.items(): - # We only merge keys that are not contained in self._meta. - self._meta.setdefault(k, v) + self._metainfo.setdefault(k, v) - # load and parse data_infos - data_infos = [] + # load and parse data_list. + data_list = [] for raw_data_info in raw_data_infos: - data_info = self.parse_annotations(raw_data_info) + # parse raw data information to target format + data_info = self.parse_data_info(raw_data_info) if isinstance(data_info, dict): - data_infos.append(data_info) + # For image tasks, `data_info` should be the information of a + # single image, such as dict(img_path='xxx', width=360, ...) + data_list.append(data_info) elif isinstance(data_info, list): - # data_info can also be a list of dict, which means one - # data_info contains multiple samples. + # For video tasks, `data_info` could contain image + # information of multiple frames, such as + # [dict(video_path='xxx', timestamps=...), + # dict(video_path='xxx', timestamps=...)] for item in data_info: if not isinstance(item, dict): - raise TypeError('data_info must be a list[dict], but ' + raise TypeError('data_info must be list of dict, but ' f'got {type(item)}') - data_infos.extend(data_info) + data_list.extend(data_info) else: - raise TypeError('data_info should be a dict or list[dict], ' + raise TypeError('data_info should be a dict or list of dict, ' f'but got {type(data_info)}') - return data_infos + return data_list @classmethod - def _get_meta_data(cls, in_meta: dict = None) -> dict: + def _get_meta_info(cls, in_metainfo: dict = None) -> dict: """Collect meta information from the dictionary of meta. Args: - in_meta (dict): Meta information dict. If ``in_meta`` contains - existed filename, it will be parsed by ``list_from_file``. + in_metainfo (dict): Meta information dict. If ``in_metainfo`` + contains an existing filename, it will be parsed by + ``list_from_file``. Returns: dict: Parsed meta information. """ - # cls.META will be overwritten by in_meta - cls_meta = copy.deepcopy(cls.META) - if in_meta is None: - return cls_meta - if not isinstance(in_meta, dict): + # `cls.METAINFO` will be overwritten by `in_metainfo` + cls_metainfo = copy.deepcopy(cls.METAINFO) + if in_metainfo is None: + return cls_metainfo + if not isinstance(in_metainfo, dict): raise TypeError( - f'in_meta should be a dict, but got {type(in_meta)}') + f'in_metainfo should be a dict, but got {type(in_metainfo)}') - for k, v in in_meta.items(): + for k, v in in_metainfo.items(): if isinstance(v, str) and osp.isfile(v): - # if filename in in_meta, this key will be further parsed. - # nested filename will be ignored.: - cls_meta[k] = list_from_file(v) + # if filename in in_metainfo, this key will be further parsed. + # nested filename will be ignored. + cls_metainfo[k] = list_from_file(v) else: - cls_meta[k] = v + cls_metainfo[k] = v - return cls_meta + return cls_metainfo def _join_prefix(self): """Join ``self.data_root`` with ``self.data_prefix`` and @@ -465,9 +520,12 @@ class BaseDataset(Dataset): >>> self.ann_file 'a/b/c/f' """ + # Automatically join annotation file path with `self.data_root` if + # `self.ann_file` is not an absolute path. if not osp.isabs(self.ann_file): self.ann_file = osp.join(self.data_root, self.ann_file) - + # Automatically join data directory with `self.data_root` if path value + # in `self.data_prefix` is not an absolute path.
for data_key, prefix in self.data_prefix.items(): if prefix is None: self.data_prefix[data_key] = self.data_root @@ -479,31 +537,208 @@ class BaseDataset(Dataset): raise TypeError('prefix should be a string or None, but got ' f'{type(prefix)}') - def _slice_data(self) -> List[dict]: - """Slice ``self.data_infos``. BaseDataset supports only using the first - few data. + @force_full_init + def get_subset_(self, indices: Union[Sequence[int], int]) -> None: + """The in-place version of ``get_subset `` to convert dataset to a + subset of original dataset. + + This method will convert the original dataset to a subset of dataset. + If type of indices is int, ``get_subset_`` will return a subdataset + which contains the first or last few data information according to + indices is positive or negative. If type of indices is a sequence of + int, the subdataset will extract the data information according to + the index given in indices. + + Examples: + >>> dataset = BaseDataset('path/to/ann_file') + >>> len(dataset) + 100 + >>> dataset.get_subset_(90) + >>> len(dataset) + 90 + >>> # if type of indices is sequence, extract the corresponding + >>> # index data information + >>> dataset.get_subset_([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) + >>> len(dataset) + 10 + >>> dataset.get_subset_(-3) + >>> len(dataset) # Get the latest few data information. + 3 + + Args: + indices (int or Sequence[int]): If type of indices is int, indices + represents the first or last few data of dataset according to + indices is positive or negative. If type of indices is + Sequence, indices represents the target data information + index of dataset. + """ + # Get subset of data from serialized data or data information sequence + # according to `self.serialize_data`. + if self.serialize_data: + self.date_bytes, self.data_address = \ + self._get_serialized_subset(indices) + else: + self.data_list = self._get_unserialized_subset(indices) + + @force_full_init + def get_subset(self, indices: Union[Sequence[int], int]) -> 'BaseDataset': + """Return a subset of dataset. + + This method will return a subset of original dataset. If type of + indices is int, ``get_subset_`` will return a subdataset which + contains the first or last few data information according to + indices is positive or negative. If type of indices is a sequence of + int, the subdataset will extract the information according to the index + given in indices. + + Examples: + >>> dataset = BaseDataset('path/to/ann_file') + >>> len(dataset) + 100 + >>> subdataset = dataset.get_subset(90) + >>> len(sub_dataset) + 90 + >>> # if type of indices is list, extract the corresponding + >>> # index data information + >>> subdataset = dataset.get_subset([0, 1, 2, 3, 4, 5, 6, 7, + >>> 8, 9]) + >>> len(sub_dataset) + 10 + >>> subdataset = dataset.get_subset(-3) + >>> len(subdataset) # Get the latest few data information. + 3 + + Args: + indices (int or Sequence[int]): If type of indices is int, indices + represents the first or last few data of dataset according to + indices is positive or negative. If type of indices is + Sequence, indices represents the target data information + index of dataset. + + Returns: + BaseDataset: A subset of dataset. + """ + # Get subset of data from serialized data or data information list + # according to `self.serialize_data`. Since `_get_serialized_subset` + # will recalculate the subset data information, + # `_copy_without_annotation` will copy all attributes except data + # information. 
+ sub_dataset = self._copy_without_annotation() + # Get subset of dataset with serialized or unserialized data. + if self.serialize_data: + date_bytes, data_address = \ + self._get_serialized_subset(indices) + sub_dataset.date_bytes = date_bytes.copy() + sub_dataset.data_address = data_address.copy() + else: + data_list = self._get_unserialized_subset(indices) + sub_dataset.data_list = copy.deepcopy(data_list) + return sub_dataset + + def _get_serialized_subset(self, indices: Union[Sequence[int], int]) \ + -> Tuple[np.ndarray, np.ndarray]: + """Get subset of serialized data information list. + + Args: + indices (int or Sequence[int]): If type of indices is int, + indices represents the first or last few data of serialized + data information list. If type of indices is Sequence, indices + represents the target data information index which consist of + subset data information. Returns: - List[dict]: A slice of ``self.data_infos`` + Tuple[np.ndarray, np.ndarray]: Subset of serialized data + information. """ - assert self._num_samples < len(self.data_infos), \ - f'Slice size({self._num_samples}) is larger than dataset ' \ - f'size({self.data_infos}, please keep `num_sample` smaller than' \ - f'{self.data_infos})' - if self._num_samples > 0: - return self.data_infos[:self._num_samples] + sub_date_bytes: Union[List, np.ndarray] + sub_data_address: Union[List, np.ndarray] + if isinstance(indices, int): + if indices >= 0: + assert indices < len(self.data_address), \ + f'{indices} is out of dataset length({len(self)})' + # Return the first few data information. + end_addr = self.data_address[indices - 1].item() \ + if indices > 0 else 0 + # Slicing operation of `np.ndarray` does not trigger a memory + # copy. + sub_date_bytes = self.date_bytes[:end_addr] + # Since the buffer size of the first few data information is + # not changed, their addresses can be reused directly. + sub_data_address = self.data_address[:indices] + else: + assert -indices <= len(self.data_address), \ + f'{indices} is out of dataset length({len(self)})' + # Return the last few data information. + ignored_bytes_size = self.data_address[indices - 1] + start_addr = self.data_address[indices - 1].item() + sub_date_bytes = self.date_bytes[start_addr:] + sub_data_address = self.data_address[indices:] + sub_data_address = sub_data_address - ignored_bytes_size + elif isinstance(indices, Sequence): + sub_date_bytes = [] + sub_data_address = [] + for idx in indices: + assert len(self) > idx >= -len(self) + start_addr = 0 if idx == 0 else \ + self.data_address[idx - 1].item() + end_addr = self.data_address[idx].item() + # Get data information by address. + sub_date_bytes.append(self.date_bytes[start_addr:end_addr]) + # Get data information size. + sub_data_address.append(end_addr - start_addr) + # Handle the case where indices is an empty list. + if sub_date_bytes: + sub_date_bytes = np.concatenate(sub_date_bytes) + sub_data_address = np.cumsum(sub_data_address) + else: + sub_date_bytes = np.array([]) + sub_data_address = np.array([]) + else: - return self.data_infos + raise TypeError('indices should be an int or sequence of int, ' + f'but got {type(indices)}') + return sub_date_bytes, sub_data_address # type: ignore + + def _get_unserialized_subset(self, indices: Union[Sequence[int], + int]) -> list: + """Get subset of data information list. + + Args: + indices (int or Sequence[int]): If type of indices is int, + indices represents the first or last few data of data + information.
If type of indices is Sequence, indices + represents the target data information index which consist + of subset data information. + + Returns: + list: Subset of data information. + """ + if isinstance(indices, int): + if indices >= 0: + # Return the first few data information. + sub_data_list = self.data_list[:indices] + else: + # Return the last few data information. + sub_data_list = self.data_list[indices:] + elif isinstance(indices, Sequence): + # Return the data information according to given indices. + subdata_list = [] + for idx in indices: + subdata_list.append(self.data_list[idx]) + sub_data_list = subdata_list + else: + raise TypeError('indices should be an int or sequence of int, ' + f'but got {type(indices)}') + return sub_data_list def _serialize_data(self) -> Tuple[np.ndarray, np.ndarray]: - """Serialize ``self.data_infos`` to save memory when launching multiple + """Serialize ``self.data_list`` to save memory when launching multiple workers in data loading. This function will be called in ``full_init``. Hold memory using serialized objects, and data loader workers can use shared RAM from master process instead of making a copy. Returns: - Tuple[np.ndarray, np.ndarray]: serialize result and corresponding + Tuple[np.ndarray, np.ndarray]: Serialized result and corresponding address. """ @@ -511,13 +746,19 @@ class BaseDataset(Dataset): buffer = pickle.dumps(data, protocol=4) return np.frombuffer(buffer, dtype=np.uint8) - serialized_data_infos_list = [_serialize(x) for x in self.data_infos] - address_list = np.asarray([len(x) for x in serialized_data_infos_list], - dtype=np.int64) + # Serialize the data information list to avoid making multiple copies + # of `self.data_list` when iterating over `torch.utils.data.DataLoader` + # with multiple workers. + data_list = [_serialize(x) for x in self.data_list] + address_list = np.asarray([len(x) for x in data_list], dtype=np.int64) data_address: np.ndarray = np.cumsum(address_list) - serialized_data_infos = np.concatenate(serialized_data_infos_list) - - return serialized_data_infos, data_address + # TODO Check if np.concatenate is necessary + data_bytes = np.concatenate(data_list) + # Empty cache for preventing making multiple copies of + # `self.data_list` when loading data with multiple processes. + self.data_list.clear() + gc.collect() + return data_bytes, data_address def _rand_another(self) -> int: """Get random index. @@ -550,4 +791,24 @@ class BaseDataset(Dataset): if self.serialize_data: return len(self.data_address) else: - return len(self.data_infos) + return len(self.data_list) + + def _copy_without_annotation(self, memo=dict()) -> 'BaseDataset': + """Deepcopy for all attributes other than ``data_list``, + ``data_address`` and ``date_bytes``. + + Args: + memo: Memo dict which is used to reconstruct complex objects + correctly.
+ """ + cls = self.__class__ + other = cls.__new__(cls) + memo[id(self)] = other + + for key, value in self.__dict__.items(): + if key in ['data_list', 'data_address', 'date_bytes']: + continue + super(BaseDataset, other).__setattr__(key, + copy.deepcopy(value, memo)) + + return other diff --git a/mmengine/dataset/dataset_wrapper.py b/mmengine/dataset/dataset_wrapper.py index ba8c83cfbb0f511295f022b10b2ad279fbcadc5b..454ec96d0df47902b88c1d1f6f2469e944075a70 100644 --- a/mmengine/dataset/dataset_wrapper.py +++ b/mmengine/dataset/dataset_wrapper.py @@ -4,7 +4,7 @@ import copy import math import warnings from collections import defaultdict -from typing import List, Sequence, Tuple +from typing import List, Sequence, Tuple, Union from torch.utils.data.dataset import ConcatDataset as _ConcatDataset @@ -16,6 +16,13 @@ class ConcatDataset(_ConcatDataset): Same as ``torch.utils.data.dataset.ConcatDataset`` and support lazy_init. + Note: + ``ConcatDataset`` should not inherit from ``BaseDataset`` since + ``get_subset`` and ``get_subset_`` could produce ambiguous meaning + sub-dataset which conflicts with original dataset. If you want to use + a sub-dataset of ``ConcatDataset``, you should set ``indices`` + arguments for wrapped dataset which inherit from ``BaseDataset``. + Args: datasets (Sequence[BaseDataset]): A list of datasets which will be concatenated. @@ -26,12 +33,12 @@ class ConcatDataset(_ConcatDataset): def __init__(self, datasets: Sequence[BaseDataset], lazy_init: bool = False): - # Only use meta of first dataset. - self._meta = datasets[0].meta + # Only use metainfo of first dataset. + self._metainfo = datasets[0].metainfo self.datasets = datasets # type: ignore for i, dataset in enumerate(datasets, 1): - if self._meta != dataset.meta: - warnings.warn( + if self._metainfo != dataset.metainfo: + raise ValueError( f'The meta information of the {i}-th dataset does not ' 'match meta information of the first dataset') @@ -40,14 +47,14 @@ class ConcatDataset(_ConcatDataset): self.full_init() @property - def meta(self) -> dict: + def metainfo(self) -> dict: """Get the meta information of the first dataset in ``self.datasets``. Returns: dict: Meta information of first dataset. """ - # Prevent `self._meta` from being modified by outside. - return copy.deepcopy(self._meta) + # Prevent `self._metainfo` from being modified by outside. + return copy.deepcopy(self._metainfo) def full_init(self): """Loop to ``full_init`` each dataset.""" @@ -77,8 +84,9 @@ class ConcatDataset(_ConcatDataset): f'absolute value of index({idx}) should not exceed dataset' f'length({len(self)}).') idx = len(self) + idx - # Get the inner index of single dataset + # Get `dataset_idx` to tell idx belongs to which dataset. dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + # Get the inner index of single dataset. if dataset_idx == 0: sample_idx = idx else: @@ -111,6 +119,26 @@ class ConcatDataset(_ConcatDataset): dataset_idx, sample_idx = self._get_ori_dataset_idx(idx) return self.datasets[dataset_idx][sample_idx] + def get_subset_(self, indices: Union[List[int], int]) -> None: + """Not supported in ``ConcatDataset`` for the ambiguous meaning of sub- + dataset.""" + raise NotImplementedError( + '`ConcatDataset` dose not support `get_subset` and ' + '`get_subset_` interfaces because this will lead to ambiguous ' + 'implementation of some methods. 
If you want to use `get_subset` ' + 'or `get_subset_` interfaces, please use them in the wrapped ' + 'dataset first and then use `ConcatDataset`.') + + def get_subset(self, indices: Union[List[int], int]) -> 'BaseDataset': + """Not supported in ``ConcatDataset`` for the ambiguous meaning of sub- + dataset.""" + raise NotImplementedError( + '`ConcatDataset` dose not support `get_subset` and ' + '`get_subset_` interfaces because this will lead to ambiguous ' + 'implementation of some methods. If you want to use `get_subset` ' + 'or `get_subset_` interfaces, please use them in the wrapped ' + 'dataset first and then use `ConcatDataset`.') + class RepeatDataset: """A wrapper of repeated dataset. @@ -120,10 +148,17 @@ class RepeatDataset: is small. Using RepeatDataset can reduce the data loading time between epochs. + Note: + ``RepeatDataset`` should not inherit from ``BaseDataset`` since + ``get_subset`` and ``get_subset_`` could produce ambiguous meaning + sub-dataset which conflicts with original dataset. If you want to use + a sub-dataset of ``RepeatDataset``, you should set ``indices`` + arguments for wrapped dataset which inherit from ``BaseDataset``. + Args: dataset (BaseDataset): The dataset to be repeated. times (int): Repeat times. - lazy_init (bool, optional): Whether to load annotation during + lazy_init (bool): Whether to load annotation during instantiation. Defaults to False. """ @@ -133,20 +168,20 @@ class RepeatDataset: lazy_init: bool = False): self.dataset = dataset self.times = times - self._meta = dataset.meta + self._metainfo = dataset.metainfo self._fully_initialized = False if not lazy_init: self.full_init() @property - def meta(self) -> dict: + def metainfo(self) -> dict: """Get the meta information of the repeated dataset. Returns: dict: The meta information of repeated dataset. """ - return copy.deepcopy(self._meta) + return copy.deepcopy(self._metainfo) def full_init(self): """Loop to ``full_init`` each dataset.""" @@ -195,6 +230,26 @@ class RepeatDataset: def __len__(self): return self.times * self._ori_len + def get_subset_(self, indices: Union[List[int], int]) -> None: + """Not supported in ``RepeatDataset`` for the ambiguous meaning of sub- + dataset.""" + raise NotImplementedError( + '`RepeatDataset` dose not support `get_subset` and ' + '`get_subset_` interfaces because this will lead to ambiguous ' + 'implementation of some methods. If you want to use `get_subset` ' + 'or `get_subset_` interfaces, please use them in the wrapped ' + 'dataset first and then use `RepeatDataset`.') + + def get_subset(self, indices: Union[List[int], int]) -> 'BaseDataset': + """Not supported in ``RepeatDataset`` for the ambiguous meaning of sub- + dataset.""" + raise NotImplementedError( + '`RepeatDataset` dose not support `get_subset` and ' + '`get_subset_` interfaces because this will lead to ambiguous ' + 'implementation of some methods. If you want to use `get_subset` ' + 'or `get_subset_` interfaces, please use them in the wrapped ' + 'dataset first and then use `RepeatDataset`.') + class ClassBalancedDataset: """A wrapper of class balanced dataset. @@ -219,6 +274,14 @@ class ClassBalancedDataset: 3. For each image I, compute the image-level repeat factor: :math:`r(I) = max_{c in I} r(c)` + Note: + ``ClassBalancedDataset`` should not inherit from ``BaseDataset`` + since ``get_subset`` and ``get_subset_`` could produce ambiguous + meaning sub-dataset which conflicts with original dataset. 
If you + want to use a sub-dataset of ``ClassBalancedDataset``, you should set + ``indices`` arguments for wrapped dataset which inherit from + ``BaseDataset``. + Args: dataset (BaseDataset): The dataset to be repeated. oversample_thr (float): frequency threshold below which data is @@ -236,20 +299,20 @@ class ClassBalancedDataset: lazy_init: bool = False): self.dataset = dataset self.oversample_thr = oversample_thr - self._meta = dataset.meta + self._metainfo = dataset.metainfo self._fully_initialized = False if not lazy_init: self.full_init() @property - def meta(self) -> dict: + def metainfo(self) -> dict: """Get the meta information of the repeated dataset. Returns: dict: The meta information of repeated dataset. """ - return copy.deepcopy(self._meta) + return copy.deepcopy(self._metainfo) def full_init(self): """Loop to ``full_init`` each dataset.""" @@ -257,9 +320,12 @@ class ClassBalancedDataset: return self.dataset.full_init() - + # Get repeat factors for each image. repeat_factors = self._get_repeat_factors(self.dataset, self.oversample_thr) + # Repeat dataset's indices according to repeat_factors. For example, + # if `repeat_factors = [1, 2, 3]`, and the `len(dataset) == 3`, + # the repeated indices will be [1, 2, 2, 3, 3, 3]. repeat_indices = [] for dataset_index, repeat_factor in enumerate(repeat_factors): repeat_indices.extend([dataset_index] * math.ceil(repeat_factor)) @@ -362,3 +428,23 @@ class ClassBalancedDataset: @force_full_init def __len__(self): return len(self.repeat_indices) + + def get_subset_(self, indices: Union[List[int], int]) -> None: + """Not supported in ``ClassBalancedDataset`` for the ambiguous meaning + of sub-dataset.""" + raise NotImplementedError( + '`ClassBalancedDataset` dose not support `get_subset` and ' + '`get_subset_` interfaces because this will lead to ambiguous ' + 'implementation of some methods. If you want to use `get_subset` ' + 'or `get_subset_` interfaces, please use them in the wrapped ' + 'dataset first and then use `ClassBalancedDataset`.') + + def get_subset(self, indices: Union[List[int], int]) -> 'BaseDataset': + """Not supported in ``ClassBalancedDataset`` for the ambiguous meaning + of sub-dataset.""" + raise NotImplementedError( + '`ClassBalancedDataset` dose not support `get_subset` and ' + '`get_subset_` interfaces because this will lead to ambiguous ' + 'implementation of some methods. 
If you want to use `get_subset` ' + 'or `get_subset_` interfaces, please use them in the wrapped ' + 'dataset first and then use `ClassBalancedDataset`.') diff --git a/tests/data/annotations/dummy_annotation.json b/tests/data/annotations/dummy_annotation.json index d36109ce731d86e607bcb73925778101b0ffa07d..5fac907e8079cdb305535733c36435460cd0e877 100644 --- a/tests/data/annotations/dummy_annotation.json +++ b/tests/data/annotations/dummy_annotation.json @@ -46,6 +46,26 @@ "extra_anns": [4,5,6] } ] + }, + { + "img_path": "gray.jpg", + "height": 512, + "width": 512, + "instances": + [ + { + "bbox": [0, 0, 10, 20], + "bbox_label": 1, + "mask": [[0,0],[0,10],[10,20],[20,0]], + "extra_anns": [1,2,3] + }, + { + "bbox": [10, 10, 110, 120], + "bbox_label": 2, + "mask": [[10,10],[10,110],[110,120],[120,10]], + "extra_anns": [4,5,6] + } + ] } ] } diff --git a/tests/test_data/test_base_dataset.py b/tests/test_data/test_base_dataset.py index e695488c53dd51d967031b14e8bfceac7c229fe0..3dbbd91917d14d38c8e61661e48f91b4a2f124b6 100644 --- a/tests/test_data/test_base_dataset.py +++ b/tests/test_data/test_base_dataset.py @@ -33,12 +33,12 @@ class TestBaseDataset: filename='test_img.jpg', height=604, width=640, sample_idx=0) imgs = torch.rand((2, 3, 32, 32)) pipeline = MagicMock(return_value=dict(imgs=imgs)) - META: dict = dict() - parse_annotations = MagicMock(return_value=data_info) + METAINFO: dict = dict() + parse_data_info = MagicMock(return_value=data_info) def _init_dataset(self): - self.dataset_type.META = self.META - self.dataset_type.parse_annotations = self.parse_annotations + self.dataset_type.METAINFO = self.METAINFO + self.dataset_type.parse_data_info = self.parse_data_info def test_init(self): self._init_dataset() @@ -48,14 +48,14 @@ class TestBaseDataset: data_prefix=dict(img='imgs'), ann_file='annotations/dummy_annotation.json') assert dataset._fully_initialized - assert hasattr(dataset, 'data_infos') + assert hasattr(dataset, 'data_list') assert hasattr(dataset, 'data_address') dataset = self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img=None), ann_file='annotations/dummy_annotation.json') assert dataset._fully_initialized - assert hasattr(dataset, 'data_infos') + assert hasattr(dataset, 'data_list') assert hasattr(dataset, 'data_address') # test the instantiation of self.base_dataset with @@ -66,9 +66,9 @@ class TestBaseDataset: ann_file='annotations/dummy_annotation.json', serialize_data=False) assert dataset._fully_initialized - assert hasattr(dataset, 'data_infos') + assert hasattr(dataset, 'data_list') assert not hasattr(dataset, 'data_address') - assert len(dataset) == 2 + assert len(dataset) == 3 assert dataset.get_data_info(0) == self.data_info # test the instantiation of self.base_dataset with lazy init @@ -78,7 +78,7 @@ class TestBaseDataset: ann_file='annotations/dummy_annotation.json', lazy_init=True) assert not dataset._fully_initialized - assert not dataset.data_infos + assert not dataset.data_list # test the instantiation of self.base_dataset if ann_file is not # existed. 
@@ -106,9 +106,9 @@ class TestBaseDataset: data_prefix=dict(img=['img']), ann_file='annotations/annotation_wrong_format.json') - # test the instantiation of self.base_dataset when `parse_annotations` + # test the instantiation of self.base_dataset when `parse_data_info` # return `list[dict]` - self.dataset_type.parse_annotations = MagicMock( + self.dataset_type.parse_data_info = MagicMock( return_value=[self.data_info, self.data_info.copy()]) dataset = self.dataset_type( @@ -117,136 +117,137 @@ class TestBaseDataset: ann_file='annotations/dummy_annotation.json') dataset.pipeline = self.pipeline assert dataset._fully_initialized - assert hasattr(dataset, 'data_infos') + assert hasattr(dataset, 'data_list') assert hasattr(dataset, 'data_address') - assert len(dataset) == 4 + assert len(dataset) == 6 assert dataset[0] == dict(imgs=self.imgs) assert dataset.get_data_info(0) == self.data_info - # test the instantiation of self.base_dataset when `parse_annotations` + # test the instantiation of self.base_dataset when `parse_data_info` # return unsupported data. with pytest.raises(TypeError): - self.dataset_type.parse_annotations = MagicMock(return_value='xxx') + self.dataset_type.parse_data_info = MagicMock(return_value='xxx') dataset = self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), ann_file='annotations/dummy_annotation.json') with pytest.raises(TypeError): - self.dataset_type.parse_annotations = MagicMock( + self.dataset_type.parse_data_info = MagicMock( return_value=[self.data_info, 'xxx']) - dataset = self.dataset_type( + self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), ann_file='annotations/dummy_annotation.json') def test_meta(self): self._init_dataset() - # test dataset.meta with setting the meta from annotation file as the - # meta of self.base_dataset + # test dataset.metainfo with setting the metainfo from annotation file + # as the metainfo of self.base_dataset. 
dataset = self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), ann_file='annotations/dummy_annotation.json') - assert dataset.meta == dict( + assert dataset.metainfo == dict( dataset_type='test_dataset', task_name='test_task', empty_list=[]) - # test dataset.meta with setting META in self.base_dataset + # test dataset.metainfo with setting METAINFO in self.base_dataset dataset_type = 'new_dataset' - self.dataset_type.META = dict( + self.dataset_type.METAINFO = dict( dataset_type=dataset_type, classes=('dog', 'cat')) dataset = self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), ann_file='annotations/dummy_annotation.json') - assert dataset.meta == dict( + assert dataset.metainfo == dict( dataset_type=dataset_type, task_name='test_task', classes=('dog', 'cat'), empty_list=[]) - # test dataset.meta with passing meta into self.base_dataset - meta = dict(classes=('dog', ), task_name='new_task') + # test dataset.metainfo with passing metainfo into self.base_dataset + metainfo = dict(classes=('dog', ), task_name='new_task') dataset = self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), ann_file='annotations/dummy_annotation.json', - meta=meta) - assert self.dataset_type.META == dict( + metainfo=metainfo) + assert self.dataset_type.METAINFO == dict( dataset_type=dataset_type, classes=('dog', 'cat')) - assert dataset.meta == dict( + assert dataset.metainfo == dict( dataset_type=dataset_type, task_name='new_task', classes=('dog', ), empty_list=[]) - # reset `base_dataset.META`, the `dataset.meta` should not change - self.dataset_type.META['classes'] = ('dog', 'cat', 'fish') - assert self.dataset_type.META == dict( + # reset `base_dataset.METAINFO`, the `dataset.metainfo` should not + # change + self.dataset_type.METAINFO['classes'] = ('dog', 'cat', 'fish') + assert self.dataset_type.METAINFO == dict( dataset_type=dataset_type, classes=('dog', 'cat', 'fish')) - assert dataset.meta == dict( + assert dataset.metainfo == dict( dataset_type=dataset_type, task_name='new_task', classes=('dog', ), empty_list=[]) - # test dataset.meta with passing meta containing a file into + # test dataset.metainfo with passing metainfo containing a file into # self.base_dataset - meta = dict( + metainfo = dict( classes=osp.join( osp.dirname(__file__), '../data/meta/classes.txt')) dataset = self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), ann_file='annotations/dummy_annotation.json', - meta=meta) - assert dataset.meta == dict( + metainfo=metainfo) + assert dataset.metainfo == dict( dataset_type=dataset_type, task_name='test_task', classes=['dog'], empty_list=[]) - # test dataset.meta with passing unsupported meta into + # test dataset.metainfo with passing unsupported metainfo into # self.base_dataset with pytest.raises(TypeError): - meta = 'dog' + metainfo = 'dog' dataset = self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), ann_file='annotations/dummy_annotation.json', - meta=meta) + metainfo=metainfo) - # test dataset.meta with passing meta into self.base_dataset and - # lazy_init is True - meta = dict(classes=('dog', )) + # test dataset.metainfo with passing metainfo into self.base_dataset + # and lazy_init is True + metainfo = dict(classes=('dog', )) dataset = self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), 
ann_file='annotations/dummy_annotation.json', - meta=meta, + metainfo=metainfo, lazy_init=True) - # 'task_name' and 'empty_list' not in dataset.meta - assert dataset.meta == dict( + # 'task_name' and 'empty_list' not in dataset.metainfo + assert dataset.metainfo == dict( dataset_type=dataset_type, classes=('dog', )) - # test whether self.base_dataset.META is changed when a customize + # test whether self.base_dataset.METAINFO is changed when a customize # dataset inherit self.base_dataset - # test reset META in ToyDataset. + # test reset METAINFO in ToyDataset. class ToyDataset(self.dataset_type): - META = dict(xxx='xxx') + METAINFO = dict(xxx='xxx') - assert ToyDataset.META == dict(xxx='xxx') - assert self.dataset_type.META == dict( + assert ToyDataset.METAINFO == dict(xxx='xxx') + assert self.dataset_type.METAINFO == dict( dataset_type=dataset_type, classes=('dog', 'cat', 'fish')) - # test update META in ToyDataset. + # test update METAINFO in ToyDataset. class ToyDataset(self.dataset_type): - META = copy.deepcopy(self.dataset_type.META) - META['classes'] = ('bird', ) + METAINFO = copy.deepcopy(self.dataset_type.METAINFO) + METAINFO['classes'] = ('bird', ) - assert ToyDataset.META == dict( + assert ToyDataset.METAINFO == dict( dataset_type=dataset_type, classes=('bird', )) - assert self.dataset_type.META == dict( + assert self.dataset_type.METAINFO == dict( dataset_type=dataset_type, classes=('dog', 'cat', 'fish')) @pytest.mark.parametrize('lazy_init', [True, False]) @@ -258,16 +259,16 @@ class TestBaseDataset: lazy_init=lazy_init) if not lazy_init: assert dataset._fully_initialized - assert hasattr(dataset, 'data_infos') - assert len(dataset) == 2 + assert hasattr(dataset, 'data_list') + assert len(dataset) == 3 else: # test `__len__()` when lazy_init is True assert not dataset._fully_initialized - assert not dataset.data_infos + assert not dataset.data_list # call `full_init()` automatically - assert len(dataset) == 2 + assert len(dataset) == 3 assert dataset._fully_initialized - assert hasattr(dataset, 'data_infos') + assert hasattr(dataset, 'data_list') def test_compose(self): # test callable transform @@ -308,30 +309,41 @@ class TestBaseDataset: dataset.pipeline = self.pipeline if not lazy_init: assert dataset._fully_initialized - assert hasattr(dataset, 'data_infos') + assert hasattr(dataset, 'data_list') assert dataset[0] == dict(imgs=self.imgs) else: - # test `__getitem__()` when lazy_init is True + # Test `__getitem__()` when lazy_init is True assert not dataset._fully_initialized - assert not dataset.data_infos - # call `full_init()` automatically + assert not dataset.data_list + # Call `full_init()` automatically assert dataset[0] == dict(imgs=self.imgs) assert dataset._fully_initialized - assert hasattr(dataset, 'data_infos') + assert hasattr(dataset, 'data_list') - # test with test mode - dataset.test_mode = True + # Test with test mode + dataset.test_mode = False assert dataset[0] == dict(imgs=self.imgs) + # Test cannot get a valid image. + dataset.prepare_data = MagicMock(return_value=None) + with pytest.raises(Exception): + dataset[0] + # Test get valid image by `_rand_another` - pipeline = MagicMock(return_value=None) - dataset.pipeline = pipeline - # test cannot get a valid image. 
- dataset.test_mode = False + def fake_prepare_data(idx): + if idx == 0: + return None + else: + return 1 + + dataset.prepare_data = fake_prepare_data + dataset[0] + dataset.test_mode = True with pytest.raises(Exception): dataset[0] @pytest.mark.parametrize('lazy_init', [True, False]) def test_get_data_info(self, lazy_init): + self._init_dataset() dataset = self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), @@ -340,16 +352,16 @@ class TestBaseDataset: if not lazy_init: assert dataset._fully_initialized - assert hasattr(dataset, 'data_infos') + assert hasattr(dataset, 'data_list') assert dataset.get_data_info(0) == self.data_info else: # test `get_data_info()` when lazy_init is True assert not dataset._fully_initialized - assert not dataset.data_infos + assert not dataset.data_list # call `full_init()` automatically assert dataset.get_data_info(0) == self.data_info assert dataset._fully_initialized - assert hasattr(dataset, 'data_infos') + assert hasattr(dataset, 'data_list') def test_force_full_init(self): with pytest.raises(AttributeError): @@ -363,40 +375,178 @@ class TestBaseDataset: class_without_full_init = ClassWithoutFullInit() class_without_full_init.foo() - @pytest.mark.parametrize('lazy_init', [True, False]) - def test_full_init(self, lazy_init): + def test_full_init(self): + self._init_dataset() dataset = self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), ann_file='annotations/dummy_annotation.json', - lazy_init=lazy_init) + lazy_init=True) dataset.pipeline = self.pipeline - if not lazy_init: - assert dataset._fully_initialized - assert hasattr(dataset, 'data_infos') - assert len(dataset) == 2 - assert dataset[0] == dict(imgs=self.imgs) - assert dataset.get_data_info(0) == self.data_info - else: - # test `full_init()` when lazy_init is True - assert not dataset._fully_initialized - assert not dataset.data_infos - # call `full_init()` manually - dataset.full_init() - assert dataset._fully_initialized - assert hasattr(dataset, 'data_infos') - assert len(dataset) == 2 - assert dataset[0] == dict(imgs=self.imgs) - assert dataset.get_data_info(0) == self.data_info + # test `full_init()` when lazy_init is True + assert not dataset._fully_initialized + assert not dataset.data_list + # call `full_init()` manually + dataset.full_init() + assert dataset._fully_initialized + assert hasattr(dataset, 'data_list') + assert len(dataset) == 3 + assert dataset[0] == dict(imgs=self.imgs) + assert dataset.get_data_info(0) == self.data_info - def test_slice_data(self): - # test the instantiation of self.base_dataset when passing num_samples + dataset = self.dataset_type( + data_root=osp.join(osp.dirname(__file__), '../data/'), + data_prefix=dict(img='imgs'), + ann_file='annotations/dummy_annotation.json', + lazy_init=False) + + dataset.pipeline = self.pipeline + assert dataset._fully_initialized + assert hasattr(dataset, 'data_list') + assert len(dataset) == 3 + assert dataset[0] == dict(imgs=self.imgs) + assert dataset.get_data_info(0) == self.data_info + + # test the instantiation of self.base_dataset when passing indices + dataset = self.dataset_type( + data_root=osp.join(osp.dirname(__file__), '../data/'), + data_prefix=dict(img=None), + ann_file='annotations/dummy_annotation.json') + dataset_sliced = self.dataset_type( + data_root=osp.join(osp.dirname(__file__), '../data/'), + data_prefix=dict(img=None), + ann_file='annotations/dummy_annotation.json', + indices=1) + assert 
dataset_sliced[0] == dataset[0]
+        assert len(dataset_sliced) == 1
+
+    @pytest.mark.parametrize(
+        'lazy_init, serialize_data',
+        ([True, False], [False, True], [True, True], [False, False]))
+    def test_get_subset_(self, lazy_init, serialize_data):
+        # Test positive int indices.
+        indices = 2
+        dataset = self.dataset_type(
+            data_root=osp.join(osp.dirname(__file__), '../data/'),
+            data_prefix=dict(img=None),
+            ann_file='annotations/dummy_annotation.json',
+            lazy_init=lazy_init,
+            serialize_data=serialize_data)
+
+        dataset_copy = copy.deepcopy(dataset)
+        dataset_copy.get_subset_(indices)
+        assert len(dataset_copy) == 2
+        for i in range(len(dataset_copy)):
+            ori_data = dataset[i]
+            assert dataset_copy[i] == ori_data
+
+        # Test negative int indices.
+        indices = -2
+        dataset_copy = copy.deepcopy(dataset)
+        dataset_copy.get_subset_(indices)
+        assert len(dataset_copy) == 2
+        for i in range(len(dataset_copy)):
+            ori_data = dataset[i + 1]
+            ori_data['sample_idx'] = i
+            assert dataset_copy[i] == ori_data
+
+        # If indices is 0, return empty dataset.
+        dataset_copy = copy.deepcopy(dataset)
+        dataset_copy.get_subset_(0)
+        assert len(dataset_copy) == 0
+
+        # Test list indices with positive element.
+        indices = [1]
+        dataset_copy = copy.deepcopy(dataset)
+        ori_data = dataset[1]
+        ori_data['sample_idx'] = 0
+        dataset_copy.get_subset_(indices)
+        assert len(dataset_copy) == 1
+        assert dataset_copy[0] == ori_data
+
+        # Test list indices with negative element.
+        indices = [-1]
+        dataset_copy = copy.deepcopy(dataset)
+        ori_data = dataset[2]
+        ori_data['sample_idx'] = 0
+        dataset_copy.get_subset_(indices)
+        assert len(dataset_copy) == 1
+        assert dataset_copy[0] == ori_data
+
+        # Test empty list.
+        indices = []
+        dataset_copy = copy.deepcopy(dataset)
+        dataset_copy.get_subset_(indices)
+        assert len(dataset_copy) == 0
+        # Test list with multiple positive indices.
+        indices = [0, 1, 2]
+        dataset_copy = copy.deepcopy(dataset)
+        dataset_copy.get_subset_(indices)
+        for i in range(len(dataset_copy)):
+            ori_data = dataset[i]
+            ori_data['sample_idx'] = i
+            assert dataset_copy[i] == ori_data
+        # Test list with multiple negative indices.
+        indices = [-1, -2, 0]
+        dataset_copy = copy.deepcopy(dataset)
+        dataset_copy.get_subset_(indices)
+        for i in range(len(dataset_copy)):
+            ori_data = dataset[len(dataset) - i - 1]
+            ori_data['sample_idx'] = i
+            assert dataset_copy[i] == ori_data
+
+        with pytest.raises(TypeError):
+            dataset.get_subset_(dict())
+
+    @pytest.mark.parametrize(
+        'lazy_init, serialize_data',
+        ([True, False], [False, True], [True, True], [False, False]))
+    def test_get_subset(self, lazy_init, serialize_data):
+        # Test positive indices.
+        indices = 2
         dataset = self.dataset_type(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img=None),
             ann_file='annotations/dummy_annotation.json',
-            num_samples=1)
-        assert len(dataset) == 1
+            lazy_init=lazy_init,
+            serialize_data=serialize_data)
+        dataset_sliced = dataset.get_subset(indices)
+        assert len(dataset_sliced) == 2
+        assert dataset_sliced[0] == dataset[0]
+        for i in range(len(dataset_sliced)):
+            assert dataset_sliced[i] == dataset[i]
+        # Test negative indices.
+        indices = -2
+        dataset_sliced = dataset.get_subset(indices)
+        assert len(dataset_sliced) == 2
+        for i in range(len(dataset_sliced)):
+            ori_data = dataset[i + 1]
+            ori_data['sample_idx'] = i
+            assert dataset_sliced[i] == ori_data
+        # If indices is 0 or empty list, return empty dataset.
+        assert len(dataset.get_subset(0)) == 0
+        assert len(dataset.get_subset([])) == 0
+        # test list indices.
+        indices = [1]
+        dataset_sliced = dataset.get_subset(indices)
+        ori_data = dataset[1]
+        ori_data['sample_idx'] = 0
+        assert len(dataset_sliced) == 1
+        assert dataset_sliced[0] == ori_data
+        # Test list with multiple positive index.
+        indices = [0, 1, 2]
+        dataset_sliced = dataset.get_subset(indices)
+        for i in range(len(dataset_sliced)):
+            ori_data = dataset[i]
+            ori_data['sample_idx'] = i
+            assert dataset_sliced[i] == ori_data
+        # Test list with multiple negative index.
+        indices = [-1, -2, 0]
+        dataset_sliced = dataset.get_subset(indices)
+        for i in range(len(dataset_sliced)):
+            ori_data = dataset[len(dataset) - i - 1]
+            ori_data['sample_idx'] = i
+            assert dataset_sliced[i] == ori_data
 
     def test_rand_another(self):
         # test the instantiation of self.base_dataset when passing num_samples
@@ -404,7 +554,7 @@ class TestBaseDataset:
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img=None),
             ann_file='annotations/dummy_annotation.json',
-            num_samples=1)
+            indices=1)
 
         assert dataset._rand_another() >= 0
         assert dataset._rand_another() < len(dataset)
@@ -416,7 +566,7 @@ class TestConcatDataset:
 
         # create dataset_a
         data_info = dict(filename='test_img.jpg', height=604, width=640)
-        dataset.parse_annotations = MagicMock(return_value=data_info)
+        dataset.parse_data_info = MagicMock(return_value=data_info)
         imgs = torch.rand((2, 3, 32, 32))
 
         self.dataset_a = dataset(
@@ -427,58 +577,44 @@ class TestConcatDataset:
         # create dataset_b
         data_info = dict(filename='gray.jpg', height=288, width=512)
-        dataset.parse_annotations = MagicMock(return_value=data_info)
+        dataset.parse_data_info = MagicMock(return_value=data_info)
         imgs = torch.rand((2, 3, 32, 32))
         self.dataset_b = dataset(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img='imgs'),
-            ann_file='annotations/dummy_annotation.json',
-            meta=dict(classes=('dog', 'cat')))
+            ann_file='annotations/dummy_annotation.json')
         self.dataset_b.pipeline = MagicMock(return_value=dict(imgs=imgs))
         # test init
         self.cat_datasets = ConcatDataset(
             datasets=[self.dataset_a, self.dataset_b])
 
     def test_full_init(self):
-        dataset = BaseDataset
-
-        # create dataset_a
-        data_info = dict(filename='test_img.jpg', height=604, width=640)
-        dataset.parse_annotations = MagicMock(return_value=data_info)
-        imgs = torch.rand((2, 3, 32, 32))
-
-        dataset_a = dataset(
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img='imgs'),
-            ann_file='annotations/dummy_annotation.json')
-        dataset_a.pipeline = MagicMock(return_value=dict(imgs=imgs))
-
-        # create dataset_b
-        data_info = dict(filename='gray.jpg', height=288, width=512)
-        dataset.parse_annotations = MagicMock(return_value=data_info)
-        imgs = torch.rand((2, 3, 32, 32))
-        dataset_b = dataset(
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img='imgs'),
-            ann_file='annotations/dummy_annotation.json',
-            meta=dict(classes=('dog', 'cat')))
-        dataset_b.pipeline = MagicMock(return_value=dict(imgs=imgs))
+        self._init_dataset()
 
         # test init with lazy_init=True
-        cat_datasets = ConcatDataset(
-            datasets=[dataset_a, dataset_b], lazy_init=True)
-        cat_datasets.full_init()
-        assert len(cat_datasets) == 4
-        cat_datasets.full_init()
-        cat_datasets._fully_initialized = False
-        cat_datasets[1]
-        assert len(cat_datasets) == 4
+        self.cat_datasets.full_init()
+        assert len(self.cat_datasets) == 6
+        self.cat_datasets.full_init()
+        self.cat_datasets._fully_initialized = False
+        self.cat_datasets[1]
+        assert len(self.cat_datasets) == 6
+
+        with pytest.raises(NotImplementedError):
+            self.cat_datasets.get_subset_(1)
+
+        with pytest.raises(NotImplementedError):
+            self.cat_datasets.get_subset(1)
+        # Different meta information will raise error.
+        with pytest.raises(ValueError):
+            dataset_b = BaseDataset(
+                data_root=osp.join(osp.dirname(__file__), '../data/'),
+                data_prefix=dict(img='imgs'),
+                ann_file='annotations/dummy_annotation.json',
+                metainfo=dict(classes=('cat')))
+            ConcatDataset(datasets=[self.dataset_a, dataset_b])
 
-    def test_meta(self):
+    def test_metainfo(self):
         self._init_dataset()
-        assert self.cat_datasets.meta == self.dataset_a.meta
-        # meta of self.cat_datasets is from the first dataset when
-        # concatnating datasets with different metas.
-        assert self.cat_datasets.meta != self.dataset_b.meta
+        assert self.cat_datasets.metainfo == self.dataset_a.metainfo
 
     def test_length(self):
         self._init_dataset()
@@ -524,7 +660,7 @@ class TestRepeatDataset:
     def _init_dataset(self):
         dataset = BaseDataset
         data_info = dict(filename='test_img.jpg', height=604, width=640)
-        dataset.parse_annotations = MagicMock(return_value=data_info)
+        dataset.parse_data_info = MagicMock(return_value=data_info)
         imgs = torch.rand((2, 3, 32, 32))
         self.dataset = dataset(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
@@ -538,31 +674,26 @@ class TestRepeatDataset:
             dataset=self.dataset, times=self.repeat_times)
 
     def test_full_init(self):
-        dataset = BaseDataset
-        data_info = dict(filename='test_img.jpg', height=604, width=640)
-        dataset.parse_annotations = MagicMock(return_value=data_info)
-        imgs = torch.rand((2, 3, 32, 32))
-        dataset = dataset(
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img='imgs'),
-            ann_file='annotations/dummy_annotation.json')
-        dataset.pipeline = MagicMock(return_value=dict(imgs=imgs))
+        self._init_dataset()
 
-        repeat_times = 5
-        # test init
-        repeat_datasets = RepeatDataset(
-            dataset=dataset, times=repeat_times, lazy_init=True)
+        self.repeat_datasets.full_init()
+        assert len(
+            self.repeat_datasets) == self.repeat_times * len(self.dataset)
+        self.repeat_datasets.full_init()
+        self.repeat_datasets._fully_initialized = False
+        self.repeat_datasets[1]
+        assert len(self.repeat_datasets) == \
+            self.repeat_times * len(self.dataset)
 
-        repeat_datasets.full_init()
-        assert len(repeat_datasets) == repeat_times * len(dataset)
-        repeat_datasets.full_init()
-        repeat_datasets._fully_initialized = False
-        repeat_datasets[1]
-        assert len(repeat_datasets) == repeat_times * len(dataset)
+        with pytest.raises(NotImplementedError):
+            self.repeat_datasets.get_subset_(1)
 
-    def test_meta(self):
+        with pytest.raises(NotImplementedError):
+            self.repeat_datasets.get_subset(1)
+
+    def test_metainfo(self):
         self._init_dataset()
-        assert self.repeat_datasets.meta == self.dataset.meta
+        assert self.repeat_datasets.metainfo == self.dataset.metainfo
 
     def test_length(self):
         self._init_dataset()
@@ -587,7 +718,7 @@ class TestClassBalancedDataset:
     def _init_dataset(self):
         dataset = BaseDataset
         data_info = dict(filename='test_img.jpg', height=604, width=640)
-        dataset.parse_annotations = MagicMock(return_value=data_info)
+        dataset.parse_data_info = MagicMock(return_value=data_info)
         imgs = torch.rand((2, 3, 32, 32))
         dataset.get_cat_ids = MagicMock(return_value=[0])
         self.dataset = dataset(
@@ -603,34 +734,26 @@ class TestClassBalancedDataset:
             self.cls_banlanced_datasets.repeat_indices = self.repeat_indices
 
     def test_full_init(self):
-        dataset = BaseDataset
-        data_info = dict(filename='test_img.jpg', height=604, width=640)
-        dataset.parse_annotations = MagicMock(return_value=data_info)
-        imgs = torch.rand((2, 3, 32, 32))
-        dataset.get_cat_ids = MagicMock(return_value=[0])
-        dataset = dataset(
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img='imgs'),
-            ann_file='annotations/dummy_annotation.json')
-        dataset.pipeline = MagicMock(return_value=dict(imgs=imgs))
+        self._init_dataset()
 
-        repeat_indices = [0, 0, 1, 1, 1]
-        # test init
-        cls_banlanced_datasets = ClassBalancedDataset(
-            dataset=dataset, oversample_thr=1e-3, lazy_init=True)
-
-        cls_banlanced_datasets.full_init()
-        cls_banlanced_datasets.repeat_indices = repeat_indices
-        assert len(cls_banlanced_datasets) == len(repeat_indices)
-        cls_banlanced_datasets.full_init()
-        cls_banlanced_datasets._fully_initialized = False
-        cls_banlanced_datasets[1]
-        cls_banlanced_datasets.repeat_indices = repeat_indices
-        assert len(cls_banlanced_datasets) == len(repeat_indices)
+        self.cls_banlanced_datasets.full_init()
+        self.cls_banlanced_datasets.repeat_indices = self.repeat_indices
+        assert len(self.cls_banlanced_datasets) == len(self.repeat_indices)
+        self.cls_banlanced_datasets.full_init()
+        self.cls_banlanced_datasets._fully_initialized = False
+        self.cls_banlanced_datasets[1]
+        self.cls_banlanced_datasets.repeat_indices = self.repeat_indices
+        assert len(self.cls_banlanced_datasets) == len(self.repeat_indices)
 
-    def test_meta(self):
+        with pytest.raises(NotImplementedError):
+            self.cls_banlanced_datasets.get_subset_(1)
+
+        with pytest.raises(NotImplementedError):
+            self.cls_banlanced_datasets.get_subset(1)
+
+    def test_metainfo(self):
         self._init_dataset()
-        assert self.cls_banlanced_datasets.meta == self.dataset.meta
+        assert self.cls_banlanced_datasets.metainfo == self.dataset.metainfo
 
     def test_length(self):
         self._init_dataset()