# Wenwei Zhang authored:
#   * add unit tests of data abstract interface * update * update
#   * update docs of data element * a draft of UT of datasample, to be
#     finished * update datasample test * update * update * fix comments
#     * fix comments * fix comments
#   Co-authored-by: liukuikun <liukuikun@sensetime.com>
# test_data_element.py 7.68 KiB
# Copyright (c) OpenMMLab. All rights reserved.
import random
from unittest import TestCase
import numpy as np
import pytest
import torch
from mmengine.data import BaseDataElement
class TestBaseDataElement(TestCase):
    """Unit tests for ``BaseDataElement``.

    The fixtures store ``torch.Tensor`` values, so value comparisons go
    through :meth:`_assert_equal` instead of ``assert a == b``: ``==`` on a
    tensor is element-wise, and ``assert`` on a multi-element result raises
    a RuntimeError instead of performing the check.
    """

    def setup_data(self):
        """Build a random ``(metainfo, data)`` pair.

        Returns:
            tuple[dict, dict]: ``metainfo`` holds an int ``img_id`` in
            ``[0, 100]`` and a 2-tuple ``img_shape`` whose entries are in
            ``[400, 600]``; ``data`` holds ``bboxes`` of shape ``(5, 4)``
            and ``scores`` of shape ``(5, )``.
        """
        metainfo = dict(
            img_id=random.randint(0, 100),
            img_shape=(random.randint(400, 600), random.randint(400, 600)))
        data = dict(bboxes=torch.rand((5, 4)), scores=torch.rand((5, )))
        return metainfo, data

    @staticmethod
    def _assert_equal(actual, expected):
        """Assert that ``actual`` equals ``expected``.

        Tensors and arrays are compared with ``torch.equal`` /
        ``np.array_equal`` because ``==`` on them is element-wise; any
        other value is compared with ``==``.
        """
        if isinstance(expected, torch.Tensor):
            assert isinstance(actual, torch.Tensor)
            assert torch.equal(actual, expected)
        elif isinstance(expected, np.ndarray):
            assert isinstance(actual, np.ndarray)
            assert np.array_equal(actual, expected)
        else:
            assert actual == expected

    def check_key_value(self, instances, metainfo=None, data=None):
        """Check that ``metainfo`` and ``data`` entries are stored correctly.

        Each metainfo key must be reported by ``in``, ``keys()`` and
        ``metainfo_keys()`` but not by ``data_keys()``; each data key must
        be reported by ``in``, ``keys()`` and ``data_keys()`` but not by
        ``metainfo_keys()``. Values are checked through both ``get`` and
        attribute access.
        """
        # check the existence of keys in metainfo, data, and instances
        if metainfo:
            for k, v in metainfo.items():
                assert k in instances
                assert k in instances.keys()
                assert k in instances.metainfo_keys()
                assert k not in instances.data_keys()
                self._assert_equal(instances.get(k), v)
                self._assert_equal(getattr(instances, k), v)
        if data:
            for k, v in data.items():
                assert k in instances
                assert k in instances.keys()
                assert k not in instances.metainfo_keys()
                assert k in instances.data_keys()
                self._assert_equal(instances.get(k), v)
                self._assert_equal(getattr(instances, k), v)

    def check_data_device(self, instances, device):
        """Check that ``instances`` and all of its tensor data are on ``device``.

        Both sides are normalized with ``str``: a ``torch.device`` does not
        compare equal to a plain string such as ``'cpu'``, so comparing
        them directly would fail even when the device matches.
        """
        expected = str(device)
        assert str(instances.device) == expected
        for v in instances.data_values():
            if isinstance(v, torch.Tensor):
                assert str(v.device) == expected

    def check_data_dtype(self, instances, dtype):
        """Check that every tensor/array data value is an instance of ``dtype``.

        ``dtype`` is a Python type (e.g. ``np.ndarray`` or ``torch.Tensor``),
        not a numeric dtype.
        """
        for v in instances.data_values():
            if isinstance(v, (torch.Tensor, np.ndarray)):
                assert isinstance(v, dtype)

    def test_init(self):
        # initialization with no data and metainfo
        metainfo, data = self.setup_data()
        instances = BaseDataElement()
        for k in metainfo:
            assert k not in instances
            assert instances.get(k, None) is None
        for k in data:
            assert k not in instances
            assert instances.get(k, 'abc') == 'abc'

        # initialization with kwargs
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo, data=data)
        self.check_key_value(instances, metainfo, data)

        # initialization with args
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo, data)
        self.check_key_value(instances, metainfo, data)

        # initialization with only metainfo or only data
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo)
        self.check_key_value(instances, metainfo)
        instances = BaseDataElement(data=data)
        self.check_key_value(instances, data=data)

    def test_new(self):
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo, data=data)

        # test new() with no arguments
        new_instances = instances.new()
        assert type(new_instances) is type(instances)
        assert id(new_instances.bboxes) != id(instances.bboxes)
        assert id(new_instances.bboxes) != id(data['bboxes'])
        self.check_key_value(new_instances, metainfo, data)

        # test new() with arguments
        metainfo, data = self.setup_data()
        new_instances = instances.new(metainfo=metainfo, data=data)
        assert type(new_instances) is type(instances)
        assert id(new_instances.bboxes) != id(instances.bboxes)
        assert id(new_instances.bboxes) != id(data['bboxes'])
        self.check_key_value(new_instances, metainfo, data)

    def test_set_metainfo(self):
        metainfo, _ = self.setup_data()
        instances = BaseDataElement()
        instances.set_metainfo(metainfo)
        self.check_key_value(instances, metainfo=metainfo)

        # test setting existing keys and new keys
        new_metainfo, _ = self.setup_data()
        new_metainfo.update(other=123)
        instances.set_metainfo(new_metainfo)
        self.check_key_value(instances, metainfo=new_metainfo)

    def test_set_data(self):
        metainfo, data = self.setup_data()
        instances = BaseDataElement()

        instances.bboxes = data['bboxes']
        instances.scores = data['scores']
        self.check_key_value(instances, data=data)

        # a.xx only set data rather than metainfo
        instances.img_shape = metainfo['img_shape']
        instances.img_id = metainfo['img_id']
        self.check_key_value(instances, data=metainfo)

        # assigning, via attribute access, a key that already exists in
        # metainfo is expected to raise AssertionError
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo, data)
        with self.assertRaises(AssertionError):
            instances.img_shape = metainfo['img_shape']

    def test_delete_modify(self):
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo, data)

        new_metainfo, new_data = self.setup_data()
        # img_id is drawn from only 101 values, so redraw until the new
        # metainfo really differs; otherwise the inequality checks below
        # would fail at random.
        while (new_metainfo['img_id'] == metainfo['img_id']
               or new_metainfo['img_shape'] == metainfo['img_shape']):
            new_metainfo, new_data = self.setup_data()

        # a.xx only set data rather than metainfo
        instances.bboxes = new_data['bboxes']
        instances.scores = new_data['scores']
        instances.set_metainfo(new_metainfo)
        self.check_key_value(instances, new_metainfo, new_data)

        assert not torch.equal(instances.bboxes, data['bboxes'])
        assert not torch.equal(instances.scores, data['scores'])
        assert instances.img_id != metainfo['img_id']
        assert instances.img_shape != metainfo['img_shape']

        del instances.bboxes
        self._assert_equal(instances.pop('scores', None), new_data['scores'])
        # 'scores' has already been popped, so deleting it again must fail
        with self.assertRaises(AttributeError):
            del instances.scores

        assert 'bboxes' not in instances
        assert 'scores' not in instances
        assert instances.pop('bboxes', None) is None
        assert instances.pop('scores', 'abcdef') == 'abcdef'

    @pytest.mark.skipif(
        not torch.cuda.is_available(), reason='GPU is required!')
    def test_cuda(self):
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo, data)

        # check the converted copy, not the original CPU instance
        cuda_instances = instances.cuda()
        self.check_data_device(cuda_instances, 'cuda:0')

        # here we further test to convert from cuda to cpu
        cpu_instances = cuda_instances.cpu()
        self.check_data_device(cpu_instances, 'cpu')
        del cuda_instances

        cuda_instances = instances.to('cuda:0')
        self.check_data_device(cuda_instances, 'cuda:0')

    def test_cpu(self):
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo, data)
        self.check_data_device(instances, 'cpu')

        cpu_instances = instances.cpu()
        # check_data_device covers the element device and every tensor
        # value (bboxes and scores), comparing devices as strings
        self.check_data_device(cpu_instances, 'cpu')

    def test_numpy_tensor(self):
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo, data)

        np_instances = instances.numpy()
        self.check_data_dtype(np_instances, np.ndarray)

        tensor_instances = instances.to_tensor()
        self.check_data_dtype(tensor_instances, torch.Tensor)

    def test_repr(self):
        metainfo = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
        instances = BaseDataElement(metainfo=metainfo)
        instances.det_labels = torch.LongTensor([0, 1, 2, 3])
        instances.det_scores = torch.Tensor([0.01, 0.1, 0.2, 0.3])
        expected_prefix = ('<BaseDataElement(\n'
                           ' META INFORMATION\n'
                           'img_shape: (800, 1196, 3)\n'
                           'pad_shape: (800, 1216, 3)\n'
                           ' DATA FIELDS\n'
                           'shape of det_labels: torch.Size([4])\n'
                           'shape of det_scores: torch.Size([4])\n'
                           ') at 0x')
        text = repr(instances)
        # The repr ends with a memory address that differs on every run, so
        # check the fixed prefix and only require a hex number before '>'.
        assert text.startswith(expected_prefix)
        assert text.endswith('>')
        # int(..., 16) raises ValueError if the address is empty or not hex
        int(text[len(expected_prefix):-1], 16)