Skip to content
Snippets Groups Projects
Unverified Commit 41e1191c authored by Wenwei Zhang's avatar Wenwei Zhang Committed by GitHub
Browse files

add unit tests of data abstract interface (#21)


* add unit tests of data abstract interface

* update

* update

* update docs of data element

* a draft of UT of datasample, to be finished

* update datasample test

* update

* update

* fix comments

* fix comments

* fix comments

Co-authored-by: default avatarliukuikun <liukuikun@sensetime.com>
parent bbb7d625
No related branches found
No related tags found
No related merge requests found
...@@ -113,13 +113,13 @@ gt_instances2 = gt_instances1.new() ...@@ -113,13 +113,13 @@ gt_instances2 = gt_instances1.new()
gt_instances = BaseDataElement() gt_instances = BaseDataElement()
# 设置 gt_instances 的 meta 字段,img_id 和 img_shape 会被作为 metainfo 的字段成为 gt_instances 的属性 # 设置 gt_instances 的 meta 字段,img_id 和 img_shape 会被作为 metainfo 的字段成为 gt_instances 的属性
gt_instances.set_metainfo(dict(img_id=9, img_shape=(100, 100))) gt_instances.set_metainfo(dict(img_id=9, img_shape=(100, 100)))
assert 'img_shape' in gt_instaces.metainfo_keys() assert 'img_shape' in gt_instances.metainfo_keys()
# 'img_shape' 是 gt_instances 的属性 # 'img_shape' 是 gt_instances 的属性
assert 'img_shape' in gt_instaces assert 'img_shape' in gt_instances
# img_shape 不是 gt_instances 的 data 字段 # img_shape 不是 gt_instances 的 data 字段
assert 'img_shape' not in gt_instaces.data_keys() assert 'img_shape' not in gt_instances.data_keys()
# 通过 keys 来访问所有属性 # 通过 keys 来访问所有属性
assert 'img_shape' in gt_instaces.keys() assert 'img_shape' in gt_instances.keys()
# 访问类属性一样访问 'img_shape' # 访问类属性一样访问 'img_shape'
print(gt_instances.img_shape) print(gt_instances.img_shape)
...@@ -179,7 +179,7 @@ gt_instances.get('bboxes', None) # 6x4 tensor ...@@ -179,7 +179,7 @@ gt_instances.get('bboxes', None) # 6x4 tensor
# 属性的删除 # 属性的删除
del gt_instances.img_shape del gt_instances.img_shape
del gt_instances.bboxes del gt_instances.bboxes
assert 'img_shape' in gt_instances assert 'img_shape' not in gt_instances
assert 'bboxes' not in gt_instances assert 'bboxes' not in gt_instances
# 提供了便捷的属性删除和访问操作 pop # 提供了便捷的属性删除和访问操作 pop
...@@ -191,6 +191,7 @@ gt_instances.pop('bboxes', None) # None ...@@ -191,6 +191,7 @@ gt_instances.pop('bboxes', None) # None
用户可以像 torch.Tensor 那样对 `BaseDataElement` 的 data 进行状态转换,目前支持 `cuda`、`cpu`、`to`、`numpy` 等操作。 用户可以像 torch.Tensor 那样对 `BaseDataElement` 的 data 进行状态转换,目前支持 `cuda`、`cpu`、`to`、`numpy` 等操作。
其中,`to` 函数拥有和 `torch.Tensor.to()` 相同的接口,使得用户可以灵活地将被封装的 tensor 进行状态转换。 其中,`to` 函数拥有和 `torch.Tensor.to()` 相同的接口,使得用户可以灵活地将被封装的 tensor 进行状态转换。
**注意:** 这些接口只会处理类型为 np.array,torch.Tensor,或者数字的序列,其他属性的数据(如字符串)会被跳过处理。
```python ```python
# 将所有 data 转移到 GPU 上 # 将所有 data 转移到 GPU 上
......
# Copyright (c) OpenMMLab. All rights reserved.
import random
from unittest import TestCase
import numpy as np
import pytest
import torch
from mmengine.data import BaseDataElement
class TestBaseDataElement(TestCase):
    """Unit tests for the ``BaseDataElement`` container.

    Each test builds a fresh element from randomly generated metainfo
    (``img_id``, ``img_shape``) and data (``bboxes``, ``scores`` tensors) and
    exercises one part of the public interface: construction, ``new``,
    ``set_metainfo``, attribute set/delete, device/type conversion and
    ``repr``.
    """

    @staticmethod
    def _values_equal(actual, expected):
        """Return ``True`` when two field values are equal.

        Tensors and arrays are compared element-wise and reduced to a single
        bool; ``assert a == b`` on a multi-element tensor raises instead of
        comparing. Everything else falls back to ``==``.
        """
        if isinstance(expected, torch.Tensor):
            return (isinstance(actual, torch.Tensor)
                    and torch.equal(actual, expected))
        if isinstance(expected, np.ndarray):
            return (isinstance(actual, np.ndarray)
                    and np.array_equal(actual, expected))
        return actual == expected

    @staticmethod
    def _mask_addresses(text):
        """Replace every hexadecimal address in ``text`` with ``0x0``.

        Object addresses differ between runs, so ``repr`` output can only be
        compared once they are normalised.
        """
        import re  # local import keeps the module import block untouched
        return re.sub(r'0x[0-9a-fA-F]+', '0x0', text)

    def setup_data(self):
        """Return a ``(metainfo, data)`` pair of freshly randomised dicts."""
        metainfo = dict(
            img_id=random.randint(0, 100),
            img_shape=(random.randint(400, 600), random.randint(400, 600)))
        data = dict(bboxes=torch.rand((5, 4)), scores=torch.rand((5, )))
        return metainfo, data

    def check_key_value(self, instances, metainfo=None, data=None):
        """Assert that ``instances`` exposes the given metainfo/data entries.

        Args:
            instances: The element under test.
            metainfo (dict, optional): Entries expected to be metainfo fields.
            data (dict, optional): Entries expected to be data fields.
        """
        if metainfo:
            for k, v in metainfo.items():
                assert k in instances
                assert k in instances.keys()
                assert k in instances.metainfo_keys()
                assert k not in instances.data_keys()
                assert self._values_equal(instances.get(k), v)
                assert self._values_equal(getattr(instances, k), v)
        if data:
            for k, v in data.items():
                assert k in instances
                assert k in instances.keys()
                assert k not in instances.metainfo_keys()
                assert k in instances.data_keys()
                assert self._values_equal(instances.get(k), v)
                assert self._values_equal(getattr(instances, k), v)

    def check_data_device(self, instances, device):
        """Assert that ``instances`` and its tensor fields live on ``device``.

        Args:
            instances: The element under test.
            device (str): Device string such as ``'cpu'`` or ``'cuda:0'``.
        """
        # NOTE(review): assumes ``instances.device`` compares equal to the
        # plain device string -- confirm against BaseDataElement.
        assert instances.device == device
        # Compare device objects, not device-vs-str.
        expected_device = torch.device(device)
        for v in instances.data_values():
            if isinstance(v, torch.Tensor):
                assert v.device == expected_device

    def check_data_dtype(self, instances, dtype):
        """Assert every tensor/array data field is an instance of ``dtype``.

        Args:
            instances: The element under test.
            dtype (type): Expected container type (``torch.Tensor`` or
                ``np.ndarray``).
        """
        for v in instances.data_values():
            if isinstance(v, (torch.Tensor, np.ndarray)):
                assert isinstance(v, dtype)

    def test_init(self):
        """Construction with no args, kwargs, positional args, partial args."""
        # initialization with no data and metainfo
        metainfo, data = self.setup_data()
        instances = BaseDataElement()
        for k in metainfo:
            assert k not in instances
            assert instances.get(k, None) is None
        for k in data:
            assert k not in instances
            assert instances.get(k, 'abc') == 'abc'

        # initialization with kwargs
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo, data=data)
        self.check_key_value(instances, metainfo, data)

        # initialization with args
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo, data)
        self.check_key_value(instances, metainfo, data)

        # initialization with only metainfo or only data
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo)
        self.check_key_value(instances, metainfo)
        instances = BaseDataElement(data=data)
        self.check_key_value(instances, data=data)

    def test_new(self):
        """``new()`` returns the same type holding distinct field objects."""
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo, data=data)

        # test new() with no arguments
        new_instances = instances.new()
        assert type(new_instances) is type(instances)
        assert new_instances.bboxes is not instances.bboxes
        assert new_instances.bboxes is not data['bboxes']
        self.check_key_value(new_instances, metainfo, data)

        # test new() with arguments
        metainfo, data = self.setup_data()
        new_instances = instances.new(metainfo=metainfo, data=data)
        assert type(new_instances) is type(instances)
        assert new_instances.bboxes is not instances.bboxes
        assert new_instances.bboxes is not data['bboxes']
        self.check_key_value(new_instances, metainfo, data)

    def test_set_metainfo(self):
        """``set_metainfo`` adds new keys and overwrites existing ones."""
        metainfo, _ = self.setup_data()
        instances = BaseDataElement()
        instances.set_metainfo(metainfo)
        self.check_key_value(instances, metainfo=metainfo)

        # test setting existing keys and new keys
        new_metainfo, _ = self.setup_data()
        new_metainfo.update(other=123)
        instances.set_metainfo(new_metainfo)
        self.check_key_value(instances, metainfo=new_metainfo)

    def test_set_data(self):
        """Attribute assignment creates data fields, never metainfo fields."""
        metainfo, data = self.setup_data()
        instances = BaseDataElement()
        instances.bboxes = data['bboxes']
        instances.scores = data['scores']
        self.check_key_value(instances, data=data)

        # a.xx only set data rather than metainfo
        instances.img_shape = metainfo['img_shape']
        instances.img_id = metainfo['img_id']
        self.check_key_value(instances, data=metainfo)

        # assigning to a key that already is a metainfo field must fail
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo, data)
        with self.assertRaises(AssertionError):
            instances.img_shape = metainfo['img_shape']

    def test_delete_modify(self):
        """Fields can be overwritten, deleted with ``del`` and popped."""
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo, data)

        new_metainfo, new_data = self.setup_data()
        # Two random draws can collide; force distinct metainfo so the
        # inequality checks below are deterministic.
        new_metainfo['img_id'] = metainfo['img_id'] + 1
        new_metainfo['img_shape'] = tuple(
            side + 1 for side in metainfo['img_shape'])

        instances.bboxes = new_data['bboxes']
        instances.scores = new_data['scores']
        # metainfo is updated through set_metainfo, not attribute assignment
        instances.set_metainfo(new_metainfo)
        self.check_key_value(instances, new_metainfo, new_data)

        # tensors need an explicit reduction; ``assert a != b`` would raise
        assert not torch.equal(instances.bboxes, data['bboxes'])
        assert not torch.equal(instances.scores, data['scores'])
        assert instances.img_id != metainfo['img_id']
        assert instances.img_shape != metainfo['img_shape']

        del instances.bboxes
        assert torch.equal(instances.pop('scores', None), new_data['scores'])
        # scores was removed by pop above, so deleting it again must fail
        with self.assertRaises(AttributeError):
            del instances.scores

        assert 'bboxes' not in instances
        assert 'scores' not in instances
        assert instances.pop('bboxes', None) is None
        assert instances.pop('scores', 'abcdef') == 'abcdef'

    @pytest.mark.skipif(
        not torch.cuda.is_available(), reason='GPU is required!')
    def test_cuda(self):
        """``cuda()``, ``cpu()`` and ``to()`` move data between devices."""
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo, data)

        cuda_instances = instances.cuda()
        # check the converted object, not the source it was created from
        self.check_data_device(cuda_instances, 'cuda:0')

        # here we further test to convert from cuda to cpu
        cpu_instances = cuda_instances.cpu()
        self.check_data_device(cpu_instances, 'cpu')
        del cuda_instances

        cuda_instances = instances.to('cuda:0')
        self.check_data_device(cuda_instances, 'cuda:0')

    def test_cpu(self):
        """``cpu()`` keeps CPU data on the CPU."""
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo, data)
        self.check_data_device(instances, 'cpu')

        cpu_instances = instances.cpu()
        cpu = torch.device('cpu')
        assert cpu_instances.device == 'cpu'
        assert cpu_instances.bboxes.device == cpu
        assert cpu_instances.scores.device == cpu

    def test_numpy_tensor(self):
        """``numpy()`` and ``to_tensor()`` convert the data field types."""
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo, data)

        np_instances = instances.numpy()
        self.check_data_dtype(np_instances, np.ndarray)

        tensor_instances = instances.to_tensor()
        self.check_data_dtype(tensor_instances, torch.Tensor)

    def test_repr(self):
        """``repr`` lists metainfo and data field shapes."""
        metainfo = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
        instances = BaseDataElement(metainfo=metainfo)
        instances.det_labels = torch.LongTensor([0, 1, 2, 3])
        instances.det_scores = torch.Tensor([0.01, 0.1, 0.2, 0.3])
        expected = ('<BaseDataElement(\n'
                    '  META INFORMATION\n'
                    'img_shape: (800, 1196, 3)\n'
                    'pad_shape: (800, 1216, 3)\n'
                    '  DATA FIELDS\n'
                    'shape of det_labels: torch.Size([4])\n'
                    'shape of det_scores: torch.Size([4])\n'
                    ') at 0x7f84acd10f90>')
        # The trailing address changes on every run; compare with all
        # addresses masked instead of against a hard-coded one.
        assert (self._mask_addresses(repr(instances)) ==
                self._mask_addresses(expected))
# Copyright (c) OpenMMLab. All rights reserved.
import random
from functools import partial
from unittest import TestCase
import numpy as np
import pytest
import torch
from mmengine.data import BaseDataElement, BaseDataSample
class TestBaseDataSample(TestCase):
    """Unit tests for the ``BaseDataSample`` container.

    Each test builds a sample from randomly generated metainfo (``img_id``,
    ``img_shape``) and two ``BaseDataElement`` data fields (``gt_instances``
    and ``pred_instances``), then exercises construction, ``new``,
    ``set_metainfo``, attribute set/delete, device/type conversion, ``repr``,
    the ``_set_field``/``_del_field`` helpers and property-based subclassing.
    """

    @staticmethod
    def _values_equal(actual, expected):
        """Return ``True`` when two field values are equal.

        Tensors and arrays are compared element-wise and reduced to a single
        bool; everything else falls back to ``==``.
        """
        if isinstance(expected, torch.Tensor):
            return (isinstance(actual, torch.Tensor)
                    and torch.equal(actual, expected))
        if isinstance(expected, np.ndarray):
            return (isinstance(actual, np.ndarray)
                    and np.array_equal(actual, expected))
        return actual == expected

    @staticmethod
    def _mask_addresses(text):
        """Replace every hexadecimal address in ``text`` with ``0x0``.

        Object addresses differ between runs, so ``repr`` output can only be
        compared once they are normalised.
        """
        import re  # local import keeps the module import block untouched
        return re.sub(r'0x[0-9a-fA-F]+', '0x0', text)

    def setup_data(self):
        """Return ``(metainfo, data)`` where data holds two data elements."""
        metainfo = dict(
            img_id=random.randint(0, 100),
            img_shape=(random.randint(400, 600), random.randint(400, 600)))
        gt_instances = BaseDataElement(
            data=dict(bboxes=torch.rand((5, 4)), labels=torch.rand((5, ))))
        pred_instances = BaseDataElement(
            data=dict(bboxes=torch.rand((5, 4)), scores=torch.rand((5, ))))
        data = dict(gt_instances=gt_instances, pred_instances=pred_instances)
        return metainfo, data

    def check_key_value(self, instances, metainfo=None, data=None):
        """Assert that ``instances`` exposes the given metainfo/data entries.

        Args:
            instances: The sample under test.
            metainfo (dict, optional): Entries expected to be metainfo fields.
            data (dict, optional): Entries expected to be data fields.
        """
        if metainfo:
            for k, v in metainfo.items():
                assert k in instances
                assert k in instances.keys()
                assert k in instances.metainfo_keys()
                assert k not in instances.data_keys()
                assert self._values_equal(instances.get(k), v)
                assert self._values_equal(getattr(instances, k), v)
        if data:
            for k, v in data.items():
                assert k in instances
                assert k in instances.keys()
                assert k not in instances.metainfo_keys()
                assert k in instances.data_keys()
                assert self._values_equal(instances.get(k), v)
                assert self._values_equal(getattr(instances, k), v)

    def check_data_device(self, instances, device):
        """Assert ``instances`` and all nested tensors live on ``device``.

        Recurses into data fields that are themselves ``BaseDataElement``.

        Args:
            instances: The sample or element under test.
            device (str): Device string such as ``'cpu'`` or ``'cuda:0'``.
        """
        # NOTE(review): assumes ``instances.device`` compares equal to the
        # plain device string -- confirm against the library.
        assert instances.device == device
        # Compare device objects, not device-vs-str.
        expected_device = torch.device(device)
        for v in instances.data_values():
            if isinstance(v, torch.Tensor):
                assert v.device == expected_device
            elif isinstance(v, BaseDataElement):
                self.check_data_device(v, device)

    def check_data_dtype(self, instances, dtype):
        """Assert every nested tensor/array field is an instance of ``dtype``.

        Recurses into data fields that are themselves ``BaseDataElement``.

        Args:
            instances: The sample or element under test.
            dtype (type): Expected container type (``torch.Tensor`` or
                ``np.ndarray``).
        """
        for v in instances.data_values():
            if isinstance(v, (torch.Tensor, np.ndarray)):
                assert isinstance(v, dtype)
            if isinstance(v, BaseDataElement):
                self.check_data_dtype(v, dtype)

    def test_init(self):
        """Construction with no args, kwargs, positional args, partial args."""
        # initialization with no data and metainfo
        metainfo, data = self.setup_data()
        instances = BaseDataSample()
        for k in metainfo:
            assert k not in instances
            assert instances.get(k, None) is None
        for k in data:
            assert k not in instances
            assert instances.get(k, 'abc') == 'abc'

        # initialization with kwargs
        metainfo, data = self.setup_data()
        instances = BaseDataSample(metainfo=metainfo, data=data)
        self.check_key_value(instances, metainfo, data)

        # initialization with args
        metainfo, data = self.setup_data()
        instances = BaseDataSample(metainfo, data)
        self.check_key_value(instances, metainfo, data)

        # initialization with only metainfo or only data
        metainfo, data = self.setup_data()
        instances = BaseDataSample(metainfo=metainfo)
        self.check_key_value(instances, metainfo)
        instances = BaseDataSample(data=data)
        self.check_key_value(instances, data=data)

    def test_new(self):
        """``new()`` returns the same type with a distinct ``data`` object."""
        metainfo, data = self.setup_data()
        instances = BaseDataSample(metainfo=metainfo, data=data)

        # test new() with no arguments
        new_instances = instances.new()
        assert type(new_instances) is type(instances)
        assert new_instances.data is not instances.data
        # a sample has no ``bboxes`` field; compare its ``data`` like the
        # second half of this test does
        assert new_instances.data is not data
        self.check_key_value(new_instances, metainfo, data)

        # test new() with arguments
        metainfo, data = self.setup_data()
        new_instances = instances.new(metainfo=metainfo, data=data)
        assert type(new_instances) is type(instances)
        assert new_instances.data is not instances.data
        assert new_instances.data is not data
        self.check_key_value(new_instances, metainfo, data)

    def test_set_metainfo(self):
        """``set_metainfo`` adds new keys and overwrites existing ones."""
        metainfo, _ = self.setup_data()
        instances = BaseDataSample()
        instances.set_metainfo(metainfo)
        self.check_key_value(instances, metainfo=metainfo)

        # test setting existing keys and new keys
        new_metainfo, _ = self.setup_data()
        new_metainfo.update(other=123)
        instances.set_metainfo(new_metainfo)
        self.check_key_value(instances, metainfo=new_metainfo)

    def test_set_data(self):
        """Attribute assignment creates data fields, never metainfo fields."""
        metainfo, data = self.setup_data()
        instances = BaseDataSample()
        instances.gt_instances = data['gt_instances']
        instances.pred_instances = data['pred_instances']
        self.check_key_value(instances, data=data)

        # a.xx only set data rather than metainfo
        instances.img_shape = metainfo['img_shape']
        instances.img_id = metainfo['img_id']
        self.check_key_value(instances, data=metainfo)

    def test_delete_modify(self):
        """Fields can be overwritten, deleted with ``del`` and popped."""
        metainfo, data = self.setup_data()
        instances = BaseDataSample(metainfo, data)

        new_metainfo, new_data = self.setup_data()
        # Two random draws can collide; force distinct metainfo so the
        # inequality checks below are deterministic.
        new_metainfo['img_id'] = metainfo['img_id'] + 1
        new_metainfo['img_shape'] = tuple(
            side + 1 for side in metainfo['img_shape'])

        instances.gt_instances = new_data['gt_instances']
        instances.pred_instances = new_data['pred_instances']
        # metainfo is updated through set_metainfo, not attribute assignment
        instances.set_metainfo(new_metainfo)
        self.check_key_value(instances, new_metainfo, new_data)

        assert instances.gt_instances != data['gt_instances']
        assert instances.pred_instances != data['pred_instances']
        assert instances.img_id != metainfo['img_id']
        assert instances.img_shape != metainfo['img_shape']

        del instances.gt_instances
        assert instances.pop('pred_instances',
                             None) == new_data['pred_instances']
        # pred_instances has been deleted,
        # instances does not have the pred_instances
        with self.assertRaises(AttributeError):
            del instances.pred_instances

        assert 'gt_instances' not in instances
        assert 'pred_instances' not in instances
        assert instances.pop('gt_instances', None) is None
        assert instances.pop('pred_instances', 'abcdef') == 'abcdef'

    @pytest.mark.skipif(
        not torch.cuda.is_available(), reason='GPU is required!')
    def test_cuda(self):
        """``cuda()``, ``cpu()`` and ``to()`` move data between devices."""
        metainfo, data = self.setup_data()
        instances = BaseDataSample(metainfo, data)

        cuda_instances = instances.cuda()
        # check the converted object, not the source it was created from
        self.check_data_device(cuda_instances, 'cuda:0')

        # here we further test to convert from cuda to cpu
        cpu_instances = cuda_instances.cpu()
        self.check_data_device(cpu_instances, 'cpu')
        del cuda_instances

        cuda_instances = instances.to('cuda:0')
        self.check_data_device(cuda_instances, 'cuda:0')

    def test_cpu(self):
        """``cpu()`` keeps CPU data on the CPU."""
        metainfo, data = self.setup_data()
        instances = BaseDataSample(metainfo, data)
        self.check_data_device(instances, 'cpu')

        cpu_instances = instances.cpu()
        # The sample holds gt_instances/pred_instances, not bboxes/scores,
        # so verify devices recursively through the nested elements.
        self.check_data_device(cpu_instances, 'cpu')

    def test_numpy_tensor(self):
        """``numpy()`` and ``to_tensor()`` convert nested data field types."""
        metainfo, data = self.setup_data()
        instances = BaseDataSample(metainfo, data)

        np_instances = instances.numpy()
        self.check_data_dtype(np_instances, np.ndarray)

        tensor_instances = instances.to_tensor()
        self.check_data_dtype(tensor_instances, torch.Tensor)

    def test_repr(self):
        """``repr`` lists metainfo and the nested data element."""
        metainfo = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
        # det_scores is a field of the element, not an argument of LongTensor
        gt_instances = BaseDataElement(
            data=dict(
                det_labels=torch.LongTensor([0, 1, 2, 3]),
                det_scores=torch.Tensor([0.01, 0.1, 0.2, 0.3])))
        data = dict(gt_instances=gt_instances)
        instances = BaseDataSample(metainfo=metainfo, data=data)
        # NOTE(review): gt_instances is built without metainfo, yet the
        # expected text lists img_shape/pad_shape under the nested element --
        # confirm against the actual BaseDataSample.__repr__ output.
        expected = ('<BaseDataSample(\n'
                    '  META INFORMATION\n'
                    'img_shape: (800, 1196, 3)\n'
                    'pad_shape: (800, 1216, 3)\n'
                    '  DATA FIELDS\n'
                    '\tgt_instances: <BaseDataElement(\n'
                    '\t  META INFORMATION\n'
                    '\timg_shape: (800, 1196, 3)\n'
                    '\tpad_shape: (800, 1216, 3)\n'
                    '\t  DATA FIELDS\n'
                    '\tshape of det_labels: torch.Size([4])\n'
                    '\tshape of det_scores: torch.Size([4])\n'
                    '\t) at 0x7f84acd10f90>'
                    ') at 0x7f84acd10f90>')
        # Addresses change on every run; compare with them masked instead of
        # against hard-coded values.
        assert (self._mask_addresses(repr(instances)) ==
                self._mask_addresses(expected))

    def test_set_get_fields(self):
        """``_set_field`` stores typed fields and rejects wrong types."""
        metainfo, data = self.setup_data()
        instances = BaseDataSample(metainfo)
        for key, value in data.items():
            instances._set_field(value, key, BaseDataElement)
        self.check_key_value(instances, data=data)

        # test type check
        _, data = self.setup_data()
        instances = BaseDataSample()
        for key, value in data.items():
            with self.assertRaises(AssertionError):
                instances._set_field(value, key, BaseDataSample)

    def test_del_field(self):
        """``_del_field`` removes a field and fails on a missing one."""
        metainfo, data = self.setup_data()
        instances = BaseDataSample(metainfo)
        for key, value in data.items():
            instances._set_field(value, key, BaseDataElement)
        instances._del_field('gt_instances')
        instances._del_field('pred_instances')

        # gt_instance has been deleted, instances does not have the gt_instance
        with self.assertRaises(AttributeError):
            instances._del_field('gt_instances')
        assert 'gt_instances' not in instances
        assert 'pred_instances' not in instances

    def test_inheritance(self):
        """A subclass can expose typed fields through ``property`` objects."""

        class DetDataSample(BaseDataSample):
            # Each property binds the private field name (and, for the
            # setter, the required type) into the base-class helpers.
            proposals = property(
                fget=partial(BaseDataSample._get_field, name='_proposals'),
                fset=partial(
                    BaseDataSample._set_field,
                    name='_proposals',
                    dtype=BaseDataElement),
                fdel=partial(BaseDataSample._del_field, name='_proposals'),
                doc='Region proposals of an image')
            gt_instances = property(
                fget=partial(BaseDataSample._get_field, name='_gt_instances'),
                fset=partial(
                    BaseDataSample._set_field,
                    name='_gt_instances',
                    dtype=BaseDataElement),
                fdel=partial(BaseDataSample._del_field, name='_gt_instances'),
                doc='Ground truth instances of an image')
            pred_instances = property(
                fget=partial(
                    BaseDataSample._get_field, name='_pred_instances'),
                fset=partial(
                    BaseDataSample._set_field,
                    name='_pred_instances',
                    dtype=BaseDataElement),
                fdel=partial(
                    BaseDataSample._del_field, name='_pred_instances'),
                doc='Predicted instances of an image')

        det_sample = DetDataSample()

        # test set
        proposals = BaseDataElement(data=dict(bboxes=torch.rand((5, 4))))
        det_sample.proposals = proposals
        assert 'proposals' in det_sample

        # test get
        assert det_sample.proposals == proposals

        # test delete
        del det_sample.proposals
        assert 'proposals' not in det_sample

        # test the data whether meet the requirements
        with self.assertRaises(AssertionError):
            det_sample.proposals = torch.rand((5, 4))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment