# Copyright (c) OpenMMLab. All rights reserved.
import random
from unittest import TestCase

import numpy as np
import pytest
import torch

from mmengine.data import BaseDataElement


class TestBaseDataElement(TestCase):

    def setup_data(self):
        """Build a random ``(metainfo, data)`` pair used by the tests.

        Returns:
            tuple[dict, dict]: ``metainfo`` holds a random ``img_id`` and a
            random ``img_shape``; ``data`` holds ``gt_instances`` (random
            ``bboxes`` and ``labels``) and ``pred_instances`` (random
            ``bboxes`` and ``scores``), each a ``BaseDataElement``.
        """
        img_id = random.randint(0, 100)
        img_height = random.randint(400, 600)
        img_width = random.randint(400, 600)
        metainfo = {'img_id': img_id, 'img_shape': (img_height, img_width)}

        gt_bboxes = torch.rand((5, 4))
        gt_labels = torch.rand((5, ))
        gt = BaseDataElement(bboxes=gt_bboxes, labels=gt_labels)

        pred_bboxes = torch.rand((5, 4))
        pred_scores = torch.rand((5, ))
        pred = BaseDataElement(bboxes=pred_bboxes, scores=pred_scores)

        return metainfo, {'gt_instances': gt, 'pred_instances': pred}

    def is_equal(self, x, y):
        """Return whether ``x`` and ``y`` hold equal values.

        Both arguments must have exactly the same type. Tensors and arrays
        are equal when their shapes match and every element compares equal;
        every other value (scalars, strings, containers, ``None``,
        ``BaseDataElement`` ...) falls back to its own ``==``.

        Returns:
            bool: ``True`` if the two values are considered equal.
        """
        assert type(x) is type(y)
        if isinstance(x, (torch.Tensor, np.ndarray)):
            # Compare shapes first: element-wise ``==`` on mismatched shapes
            # either broadcasts silently or raises.
            return x.shape == y.shape and bool((x == y).all())
        # Fall back to ``==`` instead of implicitly returning ``None`` (which
        # is falsy) for types that are not explicitly listed.
        return bool(x == y)

    def check_key_value(self, instances, metainfo=None, data=None):
        """Assert that ``instances`` exposes the given fields correctly.

        Keys of ``metainfo`` must be reported as metainfo only, keys of
        ``data`` as data fields only, and every value must be retrievable
        through both ``get`` and attribute access. Empty or ``None``
        mappings are skipped.
        """

        def _verify(fields, as_meta):
            for key, expected in fields.items():
                assert key in instances
                if as_meta:
                    assert key in instances.all_keys()
                    assert key in instances.metainfo_keys()
                    assert key not in instances.keys()
                else:
                    assert key in instances.keys()
                    assert key not in instances.metainfo_keys()
                    assert key in instances.all_keys()
                assert self.is_equal(instances.get(key), expected)
                assert self.is_equal(getattr(instances, key), expected)

        if metainfo:
            _verify(metainfo, as_meta=True)
        if data:
            _verify(data, as_meta=False)

    def check_data_device(self, instances, device):
        """Assert all tensors reachable from ``instances`` sit on ``device``.

        Nested ``BaseDataElement`` values are walked as well; values of any
        other type are ignored.
        """
        to_visit = [instances]
        while to_visit:
            element = to_visit.pop()
            for value in element.values():
                if isinstance(value, torch.Tensor):
                    assert value.device == torch.device(device)
                elif isinstance(value, BaseDataElement):
                    to_visit.append(value)

    def check_data_dtype(self, instances, dtype):
        """Assert every tensor/array value in ``instances`` is a ``dtype``.

        Despite its name, ``dtype`` is a container class such as
        ``np.ndarray`` or ``torch.Tensor``, not an element dtype. Nested
        ``BaseDataElement`` values are checked recursively.
        """
        array_like = (torch.Tensor, np.ndarray)
        for value in instances.values():
            if isinstance(value, array_like):
                assert isinstance(value, dtype)
            elif isinstance(value, BaseDataElement):
                self.check_data_dtype(value, dtype)

    def check_requires_grad(self, instances):
        """Assert no tensor in ``instances`` (searched recursively) requires
        gradients."""
        for value in instances.values():
            if isinstance(value, BaseDataElement):
                self.check_requires_grad(value)
            elif isinstance(value, torch.Tensor):
                assert not value.requires_grad

    def test_init(self):
        """Test constructing a BaseDataElement with and without fields."""
        # An empty element knows none of the keys and returns the defaults.
        metainfo, data = self.setup_data()
        empty = BaseDataElement()
        for key in metainfo:
            assert key not in empty
            assert empty.get(key, None) is None
        for key in data:
            assert key not in empty
            assert empty.get(key, 'abc') == 'abc'

        # metainfo and data passed together as keyword arguments
        metainfo, data = self.setup_data()
        self.check_key_value(
            BaseDataElement(metainfo=metainfo, **data), metainfo, data)

        # metainfo alone, then data alone
        metainfo, data = self.setup_data()
        self.check_key_value(BaseDataElement(metainfo=metainfo), metainfo)
        self.check_key_value(BaseDataElement(**data), data=data)

    def test_new(self):
        """Test ``new()`` with no arguments, with metainfo and data, and
        with metainfo only."""
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo, **data)

        # test new() with no arguments
        new_instances = instances.new()
        assert type(new_instances) is type(instances)
        # Replacing the data of the new element must leave the original
        # element's fields untouched.
        _, data = self.setup_data()
        new_instances.set_data(data)
        assert not self.is_equal(new_instances.gt_instances,
                                 instances.gt_instances)
        self.check_key_value(new_instances, metainfo, data)

        # test new() with metainfo and data
        metainfo, data = self.setup_data()
        new_instances = instances.new(metainfo=metainfo, **data)
        assert type(new_instances) is type(instances)
        assert id(new_instances.gt_instances) != id(instances.gt_instances)
        _, new_data = self.setup_data()
        new_instances.set_data(new_data)
        assert id(new_instances.gt_instances) != id(data['gt_instances'])
        self.check_key_value(new_instances, metainfo, new_data)

        # test new() with metainfo only: the given metainfo is applied and
        # the data fields of the original element are carried over
        metainfo, _ = self.setup_data()
        new_instances = instances.new(metainfo=metainfo)
        assert type(new_instances) is type(instances)
        self.check_key_value(new_instances, metainfo)
        assert set(new_instances.keys()) == set(instances.keys())

    def test_clone(self):
        """Test that ``clone()`` keeps the type, metainfo and data keys."""
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo, **data)
        new_instances = instances.clone()
        assert type(new_instances) is type(instances)
        # the clone must carry the same metainfo values and data fields
        self.check_key_value(new_instances, metainfo)
        assert set(new_instances.keys()) == set(instances.keys())

    def test_set_metainfo(self):
        """Test ``set_metainfo`` for new keys, overwrites and bad input."""
        # on a fresh element every key becomes a metainfo field
        metainfo, _ = self.setup_data()
        instances = BaseDataElement()
        instances.set_metainfo(metainfo)
        self.check_key_value(instances, metainfo=metainfo)

        # overwrite the existing keys and add a brand-new one
        updated_metainfo, _ = self.setup_data()
        updated_metainfo['other'] = 123
        instances.set_metainfo(updated_metainfo)
        self.check_key_value(instances, metainfo=updated_metainfo)

        # a key already used by a data field cannot become metainfo
        _, data = self.setup_data()
        instances = BaseDataElement(**data)
        _, clashing = self.setup_data()
        with self.assertRaises(AttributeError):
            instances.set_metainfo(clashing)

        # non-dict input is rejected
        with self.assertRaises(AssertionError):
            instances.set_metainfo(123)

    def test_set_data(self):
        """Test setting data fields via attributes and ``set_data``."""
        metainfo, data = self.setup_data()
        instances = BaseDataElement()

        for name in ('gt_instances', 'pred_instances'):
            setattr(instances, name, data[name])
        self.check_key_value(instances, data=data)

        # plain attribute assignment creates data fields, even for names
        # that normally hold metainfo
        instances.img_shape = metainfo['img_shape']
        instances.img_id = metainfo['img_id']
        self.check_key_value(instances, data=metainfo)

        # assigning to a name that is already metainfo is rejected
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo, **data)
        with self.assertRaises(AttributeError):
            instances.img_shape = metainfo['img_shape']

        # the private field registries cannot be overwritten
        for private_name in ('_metainfo_fields', '_data_fields'):
            with self.assertRaises(AttributeError):
                setattr(instances, private_name, 1)

        # set_data only accepts dicts
        with self.assertRaises(AssertionError):
            instances.set_data(123)

    def test_update(self):
        """Test that ``update`` merges another element's fields in place."""
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo, **data)
        proposals = BaseDataElement(
            bboxes=torch.rand((5, 4)), scores=torch.rand((5, )))
        new_instances = BaseDataElement(proposals=proposals)
        instances.update(new_instances)
        # ``dict.update`` returns None: update the expected dict first and
        # pass the dict itself, otherwise no data field would be checked.
        data.update(proposals=proposals)
        self.check_key_value(instances, metainfo, data)

    def test_delete_modify(self):
        """Test overwriting, deleting and popping data and metainfo fields."""
        random.seed(10)
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo, **data)

        new_metainfo, new_data = self.setup_data()
        # regenerate until the new metainfo differs from the original one,
        # so the "not equal" assertions below are meaningful
        while (new_metainfo['img_id'] == metainfo['img_id']
               or new_metainfo['img_shape'] == metainfo['img_shape']):
            new_metainfo, new_data = self.setup_data()

        # overwrite existing data fields via attribute assignment
        instances.gt_instances = new_data['gt_instances']
        instances.pred_instances = new_data['pred_instances']

        # overwrite existing metainfo fields via set_metainfo
        instances.set_metainfo(new_metainfo)
        self.check_key_value(instances, new_metainfo, new_data)

        assert not self.is_equal(instances.gt_instances, data['gt_instances'])
        assert not self.is_equal(instances.pred_instances,
                                 data['pred_instances'])
        assert not self.is_equal(instances.img_id, metainfo['img_id'])
        assert not self.is_equal(instances.img_shape, metainfo['img_shape'])

        # delete a data field and a metainfo field, pop another data field
        del instances.gt_instances
        del instances.img_id
        assert not self.is_equal(
            instances.pop('pred_instances', None), data['pred_instances'])
        # deleting a field that was already popped raises
        with self.assertRaises(AttributeError):
            del instances.pred_instances
        assert 'gt_instances' not in instances
        assert 'pred_instances' not in instances
        assert 'img_id' not in instances

        # pop a missing data key: the default is returned if given ...
        assert instances.pop('gt_instances', None) is None
        # ... and KeyError is raised without a default
        with self.assertRaises(KeyError):
            instances.pop('gt_instances')
        assert instances.pop('pred_instances', 'abcdef') == 'abcdef'
        # the same holds for a missing metainfo key
        assert instances.pop('img_id', None) is None
        with self.assertRaises(KeyError):
            instances.pop('img_id')
        # pop also works for existing metainfo keys
        assert instances.pop('img_shape') == new_metainfo['img_shape']

        # the private field registries cannot be deleted
        with self.assertRaises(AttributeError):
            del instances._metainfo_fields
        with self.assertRaises(AttributeError):
            del instances._data_fields

    @pytest.mark.skipif(
        not torch.cuda.is_available(), reason='GPU is required!')
    def test_cuda(self):
        """Test moving data to CUDA via ``cuda()``/``to()`` and back."""
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo, **data)

        on_gpu = instances.cuda()
        self.check_data_device(on_gpu, 'cuda:0')

        # converting back from cuda to cpu
        back_on_cpu = on_gpu.cpu()
        self.check_data_device(back_on_cpu, 'cpu')
        del on_gpu

        self.check_data_device(instances.to('cuda:0'), 'cuda:0')

    def test_cpu(self):
        """Test that data stays on the CPU before and after ``cpu()``."""
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo, **data)
        self.check_data_device(instances, 'cpu')

        gt_on_cpu = instances.cpu().gt_instances
        cpu = torch.device('cpu')
        assert gt_on_cpu.bboxes.device == cpu
        assert gt_on_cpu.labels.device == cpu
        assert cpu_instances.gt_instances.labels.device == torch.device('cpu')

    def test_numpy_tensor(self):
        """Test round-tripping data through ``numpy()`` and ``to_tensor()``."""
        metainfo, data = self.setup_data()
        as_numpy = BaseDataElement(metainfo=metainfo, **data).numpy()
        self.check_data_dtype(as_numpy, np.ndarray)

        as_tensor = as_numpy.to_tensor()
        self.check_data_dtype(as_tensor, torch.Tensor)

    def test_detach(self):
        """Test that ``detach()`` yields an element without grad tensors."""
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo, **data)
        # Add a tensor that actually tracks gradients; the random tensors
        # from ``setup_data`` never do, which would make the check vacuous.
        instances.scores = torch.rand((5, ), requires_grad=True)
        # ``detach()`` is expected to return the detached element (like
        # ``cpu()``/``cuda()``), so check its result, not ``instances``.
        detached = instances.detach()
        self.check_requires_grad(detached)

    def test_repr(self):
        """Test the nested ``repr`` of a data element."""
        metainfo = dict(img_shape=(800, 1196, 3))
        gt_instances = BaseDataElement(
            metainfo=metainfo, det_labels=torch.LongTensor([0, 1, 2, 3]))
        sample = BaseDataElement(metainfo=metainfo, gt_instances=gt_instances)
        outer_addr = hex(id(sample))
        inner_addr = hex(id(sample.gt_instances))

        expected = (
            '<BaseDataElement(\n\n'
            '    META INFORMATION\n'
            '    img_shape: (800, 1196, 3)\n\n'
            '    DATA FIELDS\n'
            '    gt_instances: <BaseDataElement(\n        \n'
            '            META INFORMATION\n'
            '            img_shape: (800, 1196, 3)\n        \n'
            '            DATA FIELDS\n'
            '            det_labels: tensor([0, 1, 2, 3])\n'
            f'        ) at {inner_addr}>\n'
            f') at {outer_addr}>')
        assert repr(sample) == expected

    def test_set_fields(self):
        """Test ``set_field`` including its ``dtype`` check."""
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo)
        for name, field in data.items():
            instances.set_field(name=name, value=field, dtype=BaseDataElement)
        self.check_key_value(instances, data=data)

        # a value whose type does not match ``dtype`` is rejected
        _, data = self.setup_data()
        instances = BaseDataElement()
        for name, field in data.items():
            with self.assertRaises(AssertionError):
                instances.set_field(name=name, value=field, dtype=torch.Tensor)

    def test_inheritance(self):
        """Test a subclass that exposes data fields through properties."""

        def _element_property(public_name):
            # Build a property storing its value under ``_<public_name>`` via
            # ``set_field``, which only accepts BaseDataElement values.
            private_name = '_' + public_name

            def getter(self):
                return getattr(self, private_name)

            def setter(self, value):
                self.set_field(
                    value=value, name=private_name, dtype=BaseDataElement)

            def deleter(self):
                delattr(self, private_name)

            return property(getter, setter, deleter)

        class DetDataSample(BaseDataElement):
            proposals = _element_property('proposals')
            gt_instances = _element_property('gt_instances')
            pred_instances = _element_property('pred_instances')

        det_sample = DetDataSample()

        # set
        proposals = BaseDataElement(bboxes=torch.rand((5, 4)))
        det_sample.proposals = proposals
        assert 'proposals' in det_sample

        # get
        assert det_sample.proposals == proposals

        # delete
        del det_sample.proposals
        assert 'proposals' not in det_sample

        # values of the wrong type are rejected by the setter
        with self.assertRaises(AssertionError):
            det_sample.proposals = torch.rand((5, 4))

    def test_values(self):
        """Test ``metainfo_values``, ``all_values`` and ``values``."""
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo, **data)
        n_meta, n_data = len(metainfo), len(data)
        assert len(instances.metainfo_values()) == n_meta
        assert len(instances.all_values()) == n_meta + n_data
        assert len(instances.values()) == n_data

    def test_keys(self):
        """Test ``metainfo_keys``, ``all_keys`` and ``keys``."""
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo, **data)
        n_meta, n_data = len(metainfo), len(data)
        assert len(instances.metainfo_keys()) == n_meta
        assert len(instances.all_keys()) == n_data + n_meta
        assert len(instances.keys()) == n_data

    def test_items(self):
        """Test ``metainfo_items``, ``all_items`` and ``items``."""
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo, **data)
        n_meta, n_data = len(metainfo), len(data)
        assert len(dict(instances.metainfo_items())) == n_meta
        assert len(dict(instances.all_items())) == n_meta + n_data
        assert len(dict(instances.items())) == n_data

    def test_to_dict(self):
        """Test converting an element and its sub-elements to plain dicts."""
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo, **data)
        as_dict = instances.to_dict()

        # every metainfo and data key must survive the conversion
        missing = [key for key in instances.all_keys() if key not in as_dict]
        assert not missing
        assert isinstance(as_dict, dict)
        # nested data elements are converted as well
        for name in ('gt_instances', 'pred_instances'):
            assert isinstance(as_dict[name], dict)

    def test_metainfo(self):
        """Test that the ``metainfo`` property returns all metainfo."""
        metainfo, data = self.setup_data()
        element = BaseDataElement(metainfo=metainfo, **data)
        self.assertDictEqual(element.metainfo, metainfo)