# Copyright (c) OpenMMLab. All rights reserved.
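"""Unit tests for ``BaseDataset`` and its wrappers from ``mmengine.dataset``:
``ConcatDataset``, ``RepeatDataset`` and ``ClassBalancedDataset``."""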
import copy
import os.path as osp
from unittest.mock import MagicMock

import pytest
import torch

from mmengine.dataset import (BaseDataset, ClassBalancedDataset, Compose,
                              ConcatDataset, RepeatDataset, force_full_init)
from mmengine.registry import DATASETS, TRANSFORMS
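

# An identity pipeline used below to verify that `Compose` accepts plain
# callables.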
def function_pipeline(data_info):
    return data_info


@TRANSFORMS.register_module()
class CallableTransform:

    def __call__(self, data_info):
        return data_info


@TRANSFORMS.register_module()
class NotCallableTransform:
    pass


@DATASETS.register_module()
class CustomDataset(BaseDataset):
    pass


class TestBaseDataset:
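    """Tests for ``BaseDataset``: construction, metainfo handling, lazy
    initialization, indexing, subset slicing and ``Compose`` pipelines."""
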
    dataset_type = BaseDataset
    data_info = dict(
        filename='test_img.jpg', height=604, width=640, sample_idx=0)
    imgs = torch.rand((2, 3, 32, 32))
    pipeline = MagicMock(return_value=dict(imgs=imgs))
    METAINFO: dict = dict()
    parse_data_info = MagicMock(return_value=data_info)

    def _init_dataset(self):
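        # Restore class-level attributes that individual tests may overwrite.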
        self.dataset_type.METAINFO = self.METAINFO
        self.dataset_type.parse_data_info = self.parse_data_info

    def test_init(self):
        self._init_dataset()
        # test the instantiation of self.base_dataset
        dataset = self.dataset_type(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json')
        assert dataset._fully_initialized
        assert hasattr(dataset, 'data_list')
        assert hasattr(dataset, 'data_address')
        dataset = self.dataset_type(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img=None),
            ann_file='annotations/dummy_annotation.json')
        assert dataset._fully_initialized
        assert hasattr(dataset, 'data_list')
        assert hasattr(dataset, 'data_address')

        # test the instantiation of self.base_dataset with
        # `serialize_data=False`
        dataset = self.dataset_type(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json',
            serialize_data=False)
        assert dataset._fully_initialized
        assert hasattr(dataset, 'data_list')
        assert not hasattr(dataset, 'data_address')
        assert dataset.get_data_info(0) == self.data_info

        # test the instantiation of self.base_dataset with lazy init
        dataset = self.dataset_type(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json',
            lazy_init=True)
        assert not dataset._fully_initialized

        # test the instantiation of self.base_dataset if ann_file does not
        # exist.
        with pytest.raises(FileNotFoundError):
            self.dataset_type(
                data_root=osp.join(osp.dirname(__file__), '../data/'),
                data_prefix=dict(img='imgs'),
                ann_file='annotations/not_existed_annotation.json')
        # Use the default value of ann_file, i.e., ''
        with pytest.raises(FileNotFoundError):
            self.dataset_type(
                data_root=osp.join(osp.dirname(__file__), '../data/'),
                data_prefix=dict(img='imgs'))

        # test the instantiation of self.base_dataset when the ann_file is
        # wrong
        with pytest.raises(ValueError):
            self.dataset_type(
                data_root=osp.join(osp.dirname(__file__), '../data/'),
                data_prefix=dict(img='imgs'),
                ann_file='annotations/annotation_wrong_keys.json')
        with pytest.raises(TypeError):
            self.dataset_type(
                data_root=osp.join(osp.dirname(__file__), '../data/'),
                data_prefix=dict(img='imgs'),
                ann_file='annotations/annotation_wrong_format.json')
        with pytest.raises(TypeError):
            self.dataset_type(
                data_root=osp.join(osp.dirname(__file__), '../data/'),
                data_prefix=dict(img=['img']),
                ann_file='annotations/annotation_wrong_format.json')
        # test the instantiation of self.base_dataset when `parse_data_info`
        # returns `list[dict]`
        self.dataset_type.parse_data_info = MagicMock(
            return_value=[self.data_info,
                          self.data_info.copy()])
        dataset = self.dataset_type(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json')
        dataset.pipeline = self.pipeline
        assert dataset._fully_initialized
        assert hasattr(dataset, 'data_list')
        assert hasattr(dataset, 'data_address')
        assert dataset[0] == dict(imgs=self.imgs)
        assert dataset.get_data_info(0) == self.data_info

        # test the instantiation of self.base_dataset when `parse_data_info`
        # returns unsupported data.
        with pytest.raises(TypeError):
            self.dataset_type.parse_data_info = MagicMock(return_value='xxx')
            dataset = self.dataset_type(
                data_root=osp.join(osp.dirname(__file__), '../data/'),
                data_prefix=dict(img='imgs'),
                ann_file='annotations/dummy_annotation.json')
        with pytest.raises(TypeError):
            self.dataset_type.parse_data_info = MagicMock(
                return_value=[self.data_info, 'xxx'])
            dataset = self.dataset_type(
                data_root=osp.join(osp.dirname(__file__), '../data/'),
                data_prefix=dict(img='imgs'),
                ann_file='annotations/dummy_annotation.json')

    def test_meta(self):
        self._init_dataset()
        # test dataset.metainfo with setting the metainfo from annotation file
        # as the metainfo of self.base_dataset.
        dataset = self.dataset_type(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json')
        assert dataset.metainfo == dict(
            dataset_type='test_dataset', task_name='test_task', empty_list=[])
        # test dataset.metainfo with setting METAINFO in self.base_dataset
        dataset_type = 'new_dataset'
        self.dataset_type.METAINFO = dict(
            dataset_type=dataset_type, classes=('dog', 'cat'))

        dataset = self.dataset_type(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json')
        assert dataset.metainfo == dict(
            dataset_type=dataset_type,
            task_name='test_task',
            classes=('dog', 'cat'),
            empty_list=[])
        # test dataset.metainfo with passing metainfo into self.base_dataset
        metainfo = dict(classes=('dog', ), task_name='new_task')
        dataset = self.dataset_type(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json',
            metainfo=metainfo)
        assert self.dataset_type.METAINFO == dict(
            dataset_type=dataset_type, classes=('dog', 'cat'))
        assert dataset.metainfo == dict(
            dataset_type=dataset_type,
            task_name='new_task',
            classes=('dog', ),
            empty_list=[])
        # after resetting `base_dataset.METAINFO`, `dataset.metainfo` should
        # not change
        self.dataset_type.METAINFO['classes'] = ('dog', 'cat', 'fish')
        assert self.dataset_type.METAINFO == dict(
            dataset_type=dataset_type, classes=('dog', 'cat', 'fish'))
        assert dataset.metainfo == dict(
            dataset_type=dataset_type,
            task_name='new_task',
            classes=('dog', ),
            empty_list=[])

        # test dataset.metainfo with passing metainfo containing a file into
        # self.base_dataset
        metainfo = dict(
            classes=osp.join(
                osp.dirname(__file__), '../data/meta/classes.txt'))
        dataset = self.dataset_type(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json',
            metainfo=metainfo)
        assert dataset.metainfo == dict(
            dataset_type=dataset_type,
            task_name='test_task',
            classes=['dog'],
            empty_list=[])

        # test dataset.metainfo with passing unsupported metainfo into
        # self.base_dataset
        with pytest.raises(TypeError):
            dataset = self.dataset_type(
                data_root=osp.join(osp.dirname(__file__), '../data/'),
                data_prefix=dict(img='imgs'),
                ann_file='annotations/dummy_annotation.json',
                metainfo=1)  # non-dict metainfo triggers TypeError

        # test dataset.metainfo with passing metainfo into self.base_dataset
        # and lazy_init is True
        metainfo = dict(classes=('dog', ))
        dataset = self.dataset_type(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json',
            metainfo=metainfo,
            lazy_init=True)
        # 'task_name' and 'empty_list' are not in dataset.metainfo because the
        # annotation file is not loaded when `lazy_init=True`
        assert dataset.metainfo == dict(
            dataset_type=dataset_type, classes=('dog', ))

        # test that self.base_dataset.METAINFO is not changed when a
        # customized dataset inherits from self.base_dataset.
        # test resetting METAINFO in ToyDataset.
        class ToyDataset(self.dataset_type):
            METAINFO = dict(xxx='xxx')

        assert ToyDataset.METAINFO == dict(xxx='xxx')
        assert self.dataset_type.METAINFO == dict(
            dataset_type=dataset_type, classes=('dog', 'cat', 'fish'))

        # test updating METAINFO in ToyDataset.
        class ToyDataset(self.dataset_type):
            METAINFO = copy.deepcopy(self.dataset_type.METAINFO)
            METAINFO['classes'] = ('bird', )
        assert ToyDataset.METAINFO == dict(
            dataset_type=dataset_type, classes=('bird', ))
        assert self.dataset_type.METAINFO == dict(
            dataset_type=dataset_type, classes=('dog', 'cat', 'fish'))

    @pytest.mark.parametrize('lazy_init', [True, False])
    def test_length(self, lazy_init):
        dataset = self.dataset_type(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json',
            lazy_init=lazy_init)
        if not lazy_init:
            assert dataset._fully_initialized
            assert hasattr(dataset, 'data_list')
            assert len(dataset) == 3
        else:
            # test `__len__()` when lazy_init is True
            assert not dataset._fully_initialized
            # call `full_init()` automatically
            assert len(dataset) == 3
            assert dataset._fully_initialized
            assert hasattr(dataset, 'data_list')

    def test_compose(self):
        # test callable transform
        transforms = [function_pipeline]
        compose = Compose(transforms=transforms)
        assert (self.imgs == compose(dict(img=self.imgs))['img']).all()
        # test transform build from cfg_dict
        transforms = [dict(type='CallableTransform')]
        compose = Compose(transforms=transforms)
        assert (self.imgs == compose(dict(img=self.imgs))['img']).all()
        # test that `Compose` returns None early if a transform returns None
        none_func = MagicMock(return_value=None)
        transforms = [none_func, function_pipeline]
        compose = Compose(transforms=transforms)
        assert compose(dict(img=self.imgs)) is None
        # test repr
        repr_str = f'Compose(\n' \
                   f'    {none_func}\n' \
                   f'    {function_pipeline}\n' \
                   f')'
        assert repr(compose) == repr_str
        # a non-callable transform will raise an error
        with pytest.raises(TypeError):
            transforms = [dict(type='NotCallableTransform')]
            Compose(transforms)

        # transform must be callable or dict
        with pytest.raises(TypeError):
            Compose([1])

    @pytest.mark.parametrize('lazy_init', [True, False])
    def test_getitem(self, lazy_init):
        dataset = self.dataset_type(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json',
            lazy_init=lazy_init)
        dataset.pipeline = self.pipeline
        if not lazy_init:
            assert dataset._fully_initialized
            assert hasattr(dataset, 'data_list')
            assert dataset[0] == dict(imgs=self.imgs)
        else:
            # Test `__getitem__()` when lazy_init is True
            assert not dataset._fully_initialized
            assert not dataset.data_list
            # Call `full_init()` automatically
            assert dataset[0] == dict(imgs=self.imgs)
            assert dataset._fully_initialized
            assert hasattr(dataset, 'data_list')
        # Test train mode (`test_mode=False`), in which invalid samples are
        # resampled.
        dataset.test_mode = False
        assert dataset[0] == dict(imgs=self.imgs)
        # Test the case where no valid sample can be fetched at all.
        dataset.prepare_data = MagicMock(return_value=None)
        with pytest.raises(Exception):
            dataset[0]

        # Test getting a valid sample via `_rand_another`
        def fake_prepare_data(idx):
            if idx == 0:
                return None
            else:
                return 1

        dataset.prepare_data = fake_prepare_data
        dataset[0]
        # In test mode, an invalid sample raises instead of being resampled.
        dataset.test_mode = True
        with pytest.raises(Exception):
            dataset[0]

    @pytest.mark.parametrize('lazy_init', [True, False])
    def test_get_data_info(self, lazy_init):
        dataset = self.dataset_type(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json',
            lazy_init=lazy_init)

        if not lazy_init:
            assert dataset._fully_initialized
            assert hasattr(dataset, 'data_list')
            assert dataset.get_data_info(0) == self.data_info
        else:
            # test `get_data_info()` when lazy_init is True
            assert not dataset._fully_initialized
            # call `full_init()` automatically
            assert dataset.get_data_info(0) == self.data_info
            assert dataset._fully_initialized
            assert hasattr(dataset, 'data_list')

    def test_force_full_init(self):
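        # `force_full_init` expects the decorated method's owner to implement
        # `full_init()`; calling it on a class without one raises
        # AttributeError.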
        with pytest.raises(AttributeError):

            class ClassWithoutFullInit:

                @force_full_init
                def foo(self):
                    pass

            class_without_full_init = ClassWithoutFullInit()
            class_without_full_init.foo()

    def test_full_init(self):
        self._init_dataset()
        dataset = self.dataset_type(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json',
            lazy_init=True)
        dataset.pipeline = self.pipeline
        # test `full_init()` when lazy_init is True
        assert not dataset._fully_initialized
        assert not dataset.data_list
        # call `full_init()` manually
        dataset.full_init()
        assert dataset._fully_initialized
        assert hasattr(dataset, 'data_list')
        assert len(dataset) == 3
        assert dataset[0] == dict(imgs=self.imgs)
        assert dataset.get_data_info(0) == self.data_info
        dataset = self.dataset_type(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json',
            lazy_init=False)

        dataset.pipeline = self.pipeline
        assert dataset._fully_initialized
        assert hasattr(dataset, 'data_list')
        assert len(dataset) == 3
        assert dataset[0] == dict(imgs=self.imgs)
        assert dataset.get_data_info(0) == self.data_info

        # test the instantiation of self.base_dataset when passing indices
        dataset = self.dataset_type(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img=None),
            ann_file='annotations/dummy_annotation.json')
        dataset_sliced = self.dataset_type(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img=None),
            ann_file='annotations/dummy_annotation.json',
            indices=1)
        assert dataset_sliced[0] == dataset[0]
        assert len(dataset_sliced) == 1

    @pytest.mark.parametrize(
        'lazy_init, serialize_data',
        ([True, False], [False, True], [True, True], [False, False]))
    def test_get_subset_(self, lazy_init, serialize_data):
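        # `get_subset_` (trailing underscore) slices the dataset in place,
        # hence the deepcopies below; `get_subset` returns a new dataset.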
        # Test positive int indices.
        indices = 2
        dataset = self.dataset_type(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img=None),
            ann_file='annotations/dummy_annotation.json',
            lazy_init=lazy_init,
            serialize_data=serialize_data)

        dataset_copy = copy.deepcopy(dataset)
        dataset_copy.get_subset_(indices)
        assert len(dataset_copy) == 2
        for i in range(len(dataset_copy)):
            ori_data = dataset[i]
            assert dataset_copy[i] == ori_data

        # Test negative int indices.
        indices = -2
        dataset_copy = copy.deepcopy(dataset)
        dataset_copy.get_subset_(indices)
        assert len(dataset_copy) == 2
        for i in range(len(dataset_copy)):
            ori_data = dataset[i + 1]
            ori_data['sample_idx'] = i
            assert dataset_copy[i] == ori_data

        # If indices is 0, return empty dataset.
        dataset_copy = copy.deepcopy(dataset)
        dataset_copy.get_subset_(0)
        assert len(dataset_copy) == 0

        # Test list indices with positive element.
        indices = [1]
        dataset_copy = copy.deepcopy(dataset)
        ori_data = dataset[1]
        ori_data['sample_idx'] = 0
        dataset_copy.get_subset_(indices)
        assert len(dataset_copy) == 1
        assert dataset_copy[0] == ori_data

        # Test list indices with negative element.
        indices = [-1]
        dataset_copy = copy.deepcopy(dataset)
        ori_data = dataset[2]
        ori_data['sample_idx'] = 0
        dataset_copy.get_subset_(indices)
        assert len(dataset_copy) == 1
        assert dataset_copy[0] == ori_data

        # Test empty list.
        indices = []
        dataset_copy = copy.deepcopy(dataset)
        dataset_copy.get_subset_(indices)
        assert len(dataset_copy) == 0
        # Test list with multiple positive indices.
        indices = [0, 1, 2]
        dataset_copy = copy.deepcopy(dataset)
        dataset_copy.get_subset_(indices)
        for i in range(len(dataset_copy)):
            ori_data = dataset[i]
            ori_data['sample_idx'] = i
            assert dataset_copy[i] == ori_data
        # Test list with multiple negative indices.
        indices = [-1, -2, 0]
        dataset_copy = copy.deepcopy(dataset)
        dataset_copy.get_subset_(indices)
        for i in range(len(dataset_copy)):
            ori_data = dataset[len(dataset) - i - 1]
            ori_data['sample_idx'] = i
            assert dataset_copy[i] == ori_data

        with pytest.raises(TypeError):
            dataset.get_subset_(dict())

    @pytest.mark.parametrize(
        'lazy_init, serialize_data',
        ([True, False], [False, True], [True, True], [False, False]))
    def test_get_subset(self, lazy_init, serialize_data):
        # Test positive indices.
        indices = 2
        dataset = self.dataset_type(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img=None),
            ann_file='annotations/dummy_annotation.json',
            lazy_init=lazy_init,
            serialize_data=serialize_data)
        dataset_sliced = dataset.get_subset(indices)
        assert len(dataset_sliced) == 2
        assert dataset_sliced[0] == dataset[0]
        for i in range(len(dataset_sliced)):
            assert dataset_sliced[i] == dataset[i]
        # Test negative indices.
        indices = -2
        dataset_sliced = dataset.get_subset(indices)
        assert len(dataset_sliced) == 2
        for i in range(len(dataset_sliced)):
            ori_data = dataset[i + 1]
            ori_data['sample_idx'] = i
            assert dataset_sliced[i] == ori_data
        # If indices is 0 or empty list, return empty dataset.
        assert len(dataset.get_subset(0)) == 0
        assert len(dataset.get_subset([])) == 0
        # test list indices.
        indices = [1]
        dataset_sliced = dataset.get_subset(indices)
        ori_data = dataset[1]
        ori_data['sample_idx'] = 0
        assert len(dataset_sliced) == 1
        assert dataset_sliced[0] == ori_data
        # Test list with multiple positive indices.
        indices = [0, 1, 2]
        dataset_sliced = dataset.get_subset(indices)
        for i in range(len(dataset_sliced)):
            ori_data = dataset[i]
            ori_data['sample_idx'] = i
            assert dataset_sliced[i] == ori_data
        # Test list with multiple negative indices.
        indices = [-1, -2, 0]
        dataset_sliced = dataset.get_subset(indices)
        for i in range(len(dataset_sliced)):
            ori_data = dataset[len(dataset) - i - 1]
            ori_data['sample_idx'] = i
            assert dataset_sliced[i] == ori_data

    def test_rand_another(self):
        # test `_rand_another()` on self.base_dataset instantiated with a
        # subset of indices
        dataset = self.dataset_type(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img=None),
            ann_file='annotations/dummy_annotation.json',
            indices=1)  # assumed subset size; any valid `indices` works here
        assert dataset._rand_another() >= 0
        assert dataset._rand_another() < len(dataset)


class TestConcatDataset:
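    """Tests for ``ConcatDataset``, which concatenates ``dataset_a`` and
    ``dataset_b`` end to end."""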

    def setup(self):
        dataset = BaseDataset

        # create dataset_a
        data_info = dict(filename='test_img.jpg', height=604, width=640)
        dataset.parse_data_info = MagicMock(return_value=data_info)
        imgs = torch.rand((2, 3, 32, 32))
        self.dataset_a = dataset(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json')
        self.dataset_a.pipeline = MagicMock(return_value=dict(imgs=imgs))

        # create dataset_b
        data_info = dict(filename='gray.jpg', height=288, width=512)
        dataset.parse_data_info = MagicMock(return_value=data_info)
        imgs = torch.rand((2, 3, 32, 32))
        self.dataset_b = dataset(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json')
        self.dataset_b.pipeline = MagicMock(return_value=dict(imgs=imgs))
        # test init
        self.cat_datasets = ConcatDataset(
            datasets=[self.dataset_a, self.dataset_b])

    def test_init(self):
        # Test build dataset from cfg.
        dataset_cfg_b = dict(
            type=CustomDataset,
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json')
        cat_datasets = ConcatDataset(datasets=[self.dataset_a, dataset_cfg_b])
        cat_datasets.datasets[1].pipeline = self.dataset_b.pipeline
        assert len(cat_datasets) == len(self.cat_datasets)
        for i in range(len(cat_datasets)):
            assert (cat_datasets.get_data_info(i) ==
                    self.cat_datasets.get_data_info(i))
            assert (cat_datasets[i] == self.cat_datasets[i])

        with pytest.raises(TypeError):
            ConcatDataset(datasets=[0])

    def test_full_init(self):
        # test init with lazy_init=True
        self.cat_datasets.full_init()
        assert len(self.cat_datasets) == 6
        self.cat_datasets.full_init()
        self.cat_datasets._fully_initialized = False
        self.cat_datasets[1]
        assert len(self.cat_datasets) == 6

        with pytest.raises(NotImplementedError):
            self.cat_datasets.get_subset_(1)

        with pytest.raises(NotImplementedError):
            self.cat_datasets.get_subset(1)
        # Datasets with different meta information will raise a ValueError.
        with pytest.raises(ValueError):
            dataset_b = BaseDataset(
                data_root=osp.join(osp.dirname(__file__), '../data/'),
                data_prefix=dict(img='imgs'),
                ann_file='annotations/dummy_annotation.json',
                metainfo=dict(classes=('cat')))
            ConcatDataset(datasets=[self.dataset_a, dataset_b])

    def test_metainfo(self):
        assert self.cat_datasets.metainfo == self.dataset_a.metainfo

    def test_length(self):
        assert len(self.cat_datasets) == (
            len(self.dataset_a) + len(self.dataset_b))

    def test_getitem(self):
        assert (
            self.cat_datasets[0]['imgs'] == self.dataset_a[0]['imgs']).all()
        assert (self.cat_datasets[0]['imgs'] !=
                self.dataset_b[0]['imgs']).all()
        assert (
            self.cat_datasets[-1]['imgs'] == self.dataset_b[-1]['imgs']).all()
        assert (self.cat_datasets[-1]['imgs'] !=
                self.dataset_a[-1]['imgs']).all()

    def test_get_data_info(self):
        assert self.cat_datasets.get_data_info(
            0) == self.dataset_a.get_data_info(0)
        assert self.cat_datasets.get_data_info(
            0) != self.dataset_b.get_data_info(0)

        assert self.cat_datasets.get_data_info(
            -1) == self.dataset_b.get_data_info(-1)
        assert self.cat_datasets.get_data_info(
            -1) != self.dataset_a.get_data_info(-1)

    def test_get_ori_dataset_idx(self):
        assert self.cat_datasets._get_ori_dataset_idx(3) == (
            1, 3 - len(self.dataset_a))
        assert self.cat_datasets._get_ori_dataset_idx(-1) == (
            1, len(self.dataset_b) - 1)
        with pytest.raises(ValueError):
            assert self.cat_datasets._get_ori_dataset_idx(-10)


class TestRepeatDataset:
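    """Tests for ``RepeatDataset``, which repeats a dataset ``times`` times."""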

    def setup(self):
        dataset = BaseDataset
        data_info = dict(filename='test_img.jpg', height=604, width=640)
        dataset.parse_data_info = MagicMock(return_value=data_info)
        imgs = torch.rand((2, 3, 32, 32))
        self.dataset = dataset(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json')
        self.dataset.pipeline = MagicMock(return_value=dict(imgs=imgs))

        self.repeat_times = 5
        # test init
        self.repeat_datasets = RepeatDataset(
            dataset=self.dataset, times=self.repeat_times)

    def test_init(self):
        # Test build dataset from cfg.
        dataset_cfg = dict(
            type=CustomDataset,
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json')
        repeat_dataset = RepeatDataset(
            dataset=dataset_cfg, times=self.repeat_times)
        repeat_dataset.dataset.pipeline = self.dataset.pipeline
        assert len(repeat_dataset) == len(self.repeat_datasets)
        for i in range(len(repeat_dataset)):
            assert (repeat_dataset.get_data_info(i) ==
                    self.repeat_datasets.get_data_info(i))
            assert (repeat_dataset[i] == self.repeat_datasets[i])
        with pytest.raises(TypeError):
            RepeatDataset(dataset=[0], times=5)

    def test_full_init(self):
        self.repeat_datasets.full_init()
        assert len(
            self.repeat_datasets) == self.repeat_times * len(self.dataset)
        self.repeat_datasets.full_init()
        self.repeat_datasets._fully_initialized = False
        self.repeat_datasets[1]
        assert len(self.repeat_datasets) == \
               self.repeat_times * len(self.dataset)
        with pytest.raises(NotImplementedError):
            self.repeat_datasets.get_subset_(1)
        with pytest.raises(NotImplementedError):
            self.repeat_datasets.get_subset(1)

    def test_metainfo(self):
        assert self.repeat_datasets.metainfo == self.dataset.metainfo

    def test_length(self):
        assert len(
            self.repeat_datasets) == len(self.dataset) * self.repeat_times

    def test_getitem(self):
        for i in range(self.repeat_times):
            assert self.repeat_datasets[len(self.dataset) *
                                        i] == self.dataset[0]

    def test_get_data_info(self):
        for i in range(self.repeat_times):
            assert self.repeat_datasets.get_data_info(
                len(self.dataset) * i) == self.dataset.get_data_info(0)


class TestClassBalancedDataset:
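    """Tests for ``ClassBalancedDataset``, which oversamples samples whose
    category frequency falls below ``oversample_thr``."""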

    def setup(self):
        dataset = BaseDataset
        data_info = dict(filename='test_img.jpg', height=604, width=640)
        dataset.parse_data_info = MagicMock(return_value=data_info)
        imgs = torch.rand((2, 3, 32, 32))
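        # Mock every sample to category 0 so repeat factors come from a
        # single class.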
        dataset.get_cat_ids = MagicMock(return_value=[0])
        self.dataset = dataset(
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json')
        self.dataset.pipeline = MagicMock(return_value=dict(imgs=imgs))

        self.repeat_indices = [0, 0, 1, 1, 1]
        # test init
        self.cls_banlanced_datasets = ClassBalancedDataset(
            dataset=self.dataset, oversample_thr=1e-3)
        self.cls_banlanced_datasets.repeat_indices = self.repeat_indices

    def test_init(self):
        # Test build dataset from cfg.
        dataset_cfg = dict(
            type=CustomDataset,
            data_root=osp.join(osp.dirname(__file__), '../data/'),
            data_prefix=dict(img='imgs'),
            ann_file='annotations/dummy_annotation.json')
        cls_banlanced_datasets = ClassBalancedDataset(
            dataset=dataset_cfg, oversample_thr=1e-3)
        cls_banlanced_datasets.repeat_indices = self.repeat_indices
        cls_banlanced_datasets.dataset.pipeline = self.dataset.pipeline
        assert len(cls_banlanced_datasets) == len(self.cls_banlanced_datasets)
        for i in range(len(cls_banlanced_datasets)):
            assert (cls_banlanced_datasets.get_data_info(i) ==
                    self.cls_banlanced_datasets.get_data_info(i))
            assert (
                cls_banlanced_datasets[i] == self.cls_banlanced_datasets[i])

        with pytest.raises(TypeError):
            ClassBalancedDataset(dataset=[0], oversample_thr=1e-3)

    def test_full_init(self):
        self.cls_banlanced_datasets.full_init()
        self.cls_banlanced_datasets.repeat_indices = self.repeat_indices
        assert len(self.cls_banlanced_datasets) == len(self.repeat_indices)
        # Reinit `repeat_indices`.
        self.cls_banlanced_datasets._fully_initialized = False
        self.cls_banlanced_datasets.repeat_indices = self.repeat_indices
        assert len(self.cls_banlanced_datasets) != len(self.repeat_indices)
        with pytest.raises(NotImplementedError):
            self.cls_banlanced_datasets.get_subset_(1)

        with pytest.raises(NotImplementedError):
            self.cls_banlanced_datasets.get_subset(1)

    def test_metainfo(self):
        assert self.cls_banlanced_datasets.metainfo == self.dataset.metainfo

    def test_length(self):
        assert len(self.cls_banlanced_datasets) == len(self.repeat_indices)

    def test_getitem(self):
        for i in range(len(self.repeat_indices)):
            assert self.cls_banlanced_datasets[i] == self.dataset[
                self.repeat_indices[i]]

    def test_get_data_info(self):
        for i in range(len(self.repeat_indices)):
            assert self.cls_banlanced_datasets.get_data_info(
                i) == self.dataset.get_data_info(self.repeat_indices[i])

    def test_get_cat_ids(self):
        for i in range(len(self.repeat_indices)):
            assert self.cls_banlanced_datasets.get_cat_ids(
                i) == self.dataset.get_cat_ids(self.repeat_indices[i])