# Copyright (c) OpenMMLab. All rights reserved.
import random
from unittest import TestCase
import numpy as np
import pytest
import torch
from mmengine.data import BaseDataElement
class TestBaseDataElement(TestCase):
    """Unit tests for ``BaseDataElement``.

    The helpers at the top build random fixtures and perform recursive
    checks; the ``test_*`` methods exercise construction, field access,
    mutation, device/dtype conversion, ``repr`` and subclassing.
    """

    def setup_data(self):
        """Build a fresh random ``(metainfo, data)`` fixture pair.

        Returns:
            tuple[dict, dict]: ``metainfo`` with a random ``img_id`` and
            ``img_shape``, and ``data`` with ``gt_instances`` and
            ``pred_instances``, each a ``BaseDataElement`` holding random
            tensors.
        """
        metainfo = dict(
            img_id=random.randint(0, 100),
            img_shape=(random.randint(400, 600), random.randint(400, 600)))
        gt_instances = BaseDataElement(
            bboxes=torch.rand((5, 4)), labels=torch.rand((5, )))
        pred_instances = BaseDataElement(
            bboxes=torch.rand((5, 4)), scores=torch.rand((5, )))
        data = dict(gt_instances=gt_instances, pred_instances=pred_instances)
        # Every caller unpacks two values from this helper.
        return metainfo, data

    def is_equal(self, x, y):
        """Compare two values that must have exactly the same type.

        Builtin scalars/containers and ``BaseDataElement`` are compared
        with ``==``; tensors and arrays are compared element-wise and
        reduced with ``all()``. Any other type yields ``None``.
        """
        assert type(x) is type(y)
        if isinstance(
                x, (int, float, str, list, tuple, dict, set, BaseDataElement)):
            return x == y
        elif isinstance(x, (torch.Tensor, np.ndarray)):
            return (x == y).all()
        # Unsupported type: falsy result, as in the implicit fall-through.
        return None

    def check_key_value(self, instances, metainfo=None, data=None):
        """Assert that ``instances`` exposes the given metainfo and data.

        Args:
            instances: The element under test.
            metainfo (dict, optional): Entries expected among the metainfo
                keys and absent from the data keys.
            data (dict, optional): Entries expected among the data keys
                and absent from the metainfo keys.
        """
        # check the existence of keys in metainfo, data, and instances
        if metainfo:
            for k, v in metainfo.items():
                assert k in instances
                assert k in instances.all_keys()
                assert k in instances.metainfo_keys()
                assert k not in instances.keys()
                assert self.is_equal(instances.get(k), v)
                assert self.is_equal(getattr(instances, k), v)
        if data:
            for k, v in data.items():
                assert k in instances
                assert k in instances.keys()
                assert k not in instances.metainfo_keys()
                assert k in instances.all_keys()
                assert self.is_equal(instances.get(k), v)
                assert self.is_equal(getattr(instances, k), v)

    def check_data_device(self, instances, device):
        """Recursively assert every tensor value lives on ``device``."""
        for v in instances.values():
            if isinstance(v, torch.Tensor):
                assert v.device == torch.device(device)
            elif isinstance(v, BaseDataElement):
                self.check_data_device(v, device)

    def check_data_dtype(self, instances, dtype):
        """Recursively assert every tensor/array value is a ``dtype``."""
        for v in instances.values():
            if isinstance(v, (torch.Tensor, np.ndarray)):
                assert isinstance(v, dtype)
            if isinstance(v, BaseDataElement):
                self.check_data_dtype(v, dtype)

    def check_requires_grad(self, instances):
        """Recursively assert that no tensor value requires grad."""
        for v in instances.values():
            if isinstance(v, torch.Tensor):
                assert v.requires_grad is False
            if isinstance(v, BaseDataElement):
                self.check_requires_grad(v)

    def test_init(self):
        """Construction with nothing, with kwargs, and with each part."""
        # initialization with no data and metainfo
        metainfo, data = self.setup_data()
        instances = BaseDataElement()
        for k in metainfo:
            assert k not in instances
            assert instances.get(k, None) is None
        for k in data:
            assert k not in instances
            assert instances.get(k, 'abc') == 'abc'

        # initialization with kwargs
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo, **data)
        self.check_key_value(instances, metainfo, data)

        # initialization with args
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo)
        self.check_key_value(instances, metainfo)
        instances = BaseDataElement(**data)
        self.check_key_value(instances, data=data)

    def test_new(self):
        """``new()`` with and without replacement metainfo/data."""
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo, **data)

        # test new() with no arguments
        new_instances = instances.new()
        assert type(new_instances) is type(instances)
        # Replacing the data of the new element must leave the original
        # element untouched.
        _, data = self.setup_data()
        new_instances.set_data(data)
        assert not self.is_equal(new_instances.gt_instances,
                                 instances.gt_instances)
        self.check_key_value(new_instances, metainfo, data)

        # test new() with arguments
        metainfo, data = self.setup_data()
        new_instances = instances.new(metainfo=metainfo, **data)
        assert type(new_instances) is type(instances)
        assert id(new_instances.gt_instances) != id(instances.gt_instances)
        _, new_data = self.setup_data()
        new_instances.set_data(new_data)
        assert id(new_instances.gt_instances) != id(data['gt_instances'])
        self.check_key_value(new_instances, metainfo, new_data)

        # Only checks that passing metainfo alone is accepted.
        metainfo, data = self.setup_data()
        new_instances = instances.new(metainfo=metainfo)

    def test_clone(self):
        """``clone()`` returns an object of the same type."""
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo, **data)
        new_instances = instances.clone()
        assert type(new_instances) is type(instances)

    def test_set_metainfo(self):
        """``set_metainfo`` adds/overwrites keys and rejects bad input."""
        metainfo, _ = self.setup_data()
        instances = BaseDataElement()
        instances.set_metainfo(metainfo)
        self.check_key_value(instances, metainfo=metainfo)

        # test setting existing keys and new keys
        new_metainfo, _ = self.setup_data()
        new_metainfo.update(other=123)
        instances.set_metainfo(new_metainfo)
        self.check_key_value(instances, metainfo=new_metainfo)

        # test have the same key in data
        _, data = self.setup_data()
        instances = BaseDataElement(**data)
        _, data = self.setup_data()
        with self.assertRaises(AttributeError):
            instances.set_metainfo(data)

        with self.assertRaises(AssertionError):
            instances.set_metainfo(123)

    def test_set_data(self):
        """Attribute assignment stores data; protected names are refused."""
        metainfo, data = self.setup_data()
        instances = BaseDataElement()

        instances.gt_instances = data['gt_instances']
        instances.pred_instances = data['pred_instances']
        self.check_key_value(instances, data=data)

        # a.xx only set data rather than metainfo
        instances.img_shape = metainfo['img_shape']
        instances.img_id = metainfo['img_id']
        self.check_key_value(instances, data=metainfo)

        # A key that already exists in metainfo cannot be assigned as data.
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo, **data)
        with self.assertRaises(AttributeError):
            instances.img_shape = metainfo['img_shape']

        # test set '_metainfo_fields' or '_data_fields'
        with self.assertRaises(AttributeError):
            instances._metainfo_fields = 1
        with self.assertRaises(AttributeError):
            instances._data_fields = 1

        with self.assertRaises(AssertionError):
            instances.set_data(123)

    def test_update(self):
        """``update`` merges another element's fields into this one."""
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo, **data)
        proposals = BaseDataElement(
            bboxes=torch.rand((5, 4)), scores=torch.rand((5, )))
        new_instances = BaseDataElement(proposals=proposals)
        instances.update(new_instances)
        # ``dict.update`` returns None, so extend the expectation first and
        # pass the dict itself; otherwise the data check is skipped.
        data.update(proposals=proposals)
        self.check_key_value(instances, metainfo, data)

    def test_delete_modify(self):
        """Overwrite, ``del`` and ``pop`` of data and metainfo fields."""
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo, **data)

        new_metainfo, new_data = self.setup_data()
        # avoid generating same metainfo, data
        while (new_metainfo['img_id'] == metainfo['img_id']
               or new_metainfo['img_shape'] == metainfo['img_shape']):
            new_metainfo, new_data = self.setup_data()

        instances.gt_instances = new_data['gt_instances']
        instances.pred_instances = new_data['pred_instances']

        # a.xx only set data rather than metainfo
        instances.set_metainfo(new_metainfo)
        self.check_key_value(instances, new_metainfo, new_data)

        assert not self.is_equal(instances.gt_instances, data['gt_instances'])
        assert not self.is_equal(instances.pred_instances,
                                 data['pred_instances'])
        assert not self.is_equal(instances.img_id, metainfo['img_id'])
        assert not self.is_equal(instances.img_shape, metainfo['img_shape'])

        assert not self.is_equal(
            instances.pop('pred_instances', None), data['pred_instances'])
        # 'pred_instances' was popped above, so deleting it again fails.
        with self.assertRaises(AttributeError):
            del instances.pred_instances

        # Remove the remaining data field before asserting it is absent.
        del instances.gt_instances
        assert 'gt_instances' not in instances
        assert 'pred_instances' not in instances
        assert instances.pop('gt_instances', None) is None
        # test pop not exist key without default
        with self.assertRaises(KeyError):
            instances.pop('gt_instances')
        assert instances.pop('pred_instances', 'abcdef') == 'abcdef'

        # Pop the metainfo key once so the follow-up pops see it missing.
        assert instances.pop('img_id') == new_metainfo['img_id']
        assert instances.pop('img_id', None) is None
        # test pop not exist key without default
        with self.assertRaises(KeyError):
            instances.pop('img_id')
        assert instances.pop('img_shape') == new_metainfo['img_shape']

        # test del '_metainfo_fields' or '_data_fields'
        with self.assertRaises(AttributeError):
            del instances._metainfo_fields
        with self.assertRaises(AttributeError):
            del instances._data_fields

    @pytest.mark.skipif(
        not torch.cuda.is_available(), reason='GPU is required!')
    def test_cuda(self):
        """Round-trip tensors between CPU and ``cuda:0``."""
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo, **data)

        cuda_instances = instances.cuda()
        self.check_data_device(cuda_instances, 'cuda:0')

        # here we further test to convert from cuda to cpu
        cpu_instances = cuda_instances.cpu()
        self.check_data_device(cpu_instances, 'cpu')
        del cuda_instances

        cuda_instances = instances.to('cuda:0')
        self.check_data_device(cuda_instances, 'cuda:0')

    def test_cpu(self):
        """Fixture tensors start on CPU and ``cpu()`` keeps them there."""
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo, **data)
        self.check_data_device(instances, 'cpu')

        cpu_instances = instances.cpu()
        assert cpu_instances.gt_instances.bboxes.device == torch.device('cpu')
        assert cpu_instances.gt_instances.labels.device == torch.device('cpu')

    def test_numpy_tensor(self):
        """``numpy()`` yields arrays and ``to_tensor()`` yields tensors."""
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo, **data)

        np_instances = instances.numpy()
        self.check_data_dtype(np_instances, np.ndarray)

        tensor_instances = np_instances.to_tensor()
        self.check_data_dtype(tensor_instances, torch.Tensor)

    def test_detach(self):
        """After ``detach()`` no tensor value requires grad."""
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo, **data)
        instances.detach()
        self.check_requires_grad(instances)

    def test_repr(self):
        """``repr`` lists metainfo and data, nesting sub-elements."""
        metainfo = dict(img_shape=(800, 1196, 3))
        gt_instances = BaseDataElement(
            metainfo=metainfo, det_labels=torch.LongTensor([0, 1, 2, 3]))
        sample = BaseDataElement(metainfo=metainfo, gt_instances=gt_instances)
        address = hex(id(sample))
        address_gt_instances = hex(id(sample.gt_instances))
        # NOTE(review): runs of spaces inside these literals may have been
        # collapsed in this copy of the file - confirm against the real
        # ``repr`` output before relying on this expectation.
        assert repr(sample) == (
            '<BaseDataElement(\n\n'
            ' META INFORMATION\n'
            ' img_shape: (800, 1196, 3)\n\n'
            ' DATA FIELDS\n'
            ' gt_instances: <BaseDataElement(\n \n'
            ' META INFORMATION\n'
            ' img_shape: (800, 1196, 3)\n \n'
            ' DATA FIELDS\n'
            ' det_labels: tensor([0, 1, 2, 3])\n'
            f' ) at {address_gt_instances}>\n'
            f') at {address}>')

    def test_set_fields(self):
        """``set_field`` stores values and enforces the given ``dtype``."""
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo)
        for key, value in data.items():
            instances.set_field(name=key, value=value, dtype=BaseDataElement)
        self.check_key_value(instances, data=data)

        # test type check
        _, data = self.setup_data()
        instances = BaseDataElement()
        for key, value in data.items():
            with self.assertRaises(AssertionError):
                instances.set_field(name=key, value=value, dtype=torch.Tensor)

    def test_inheritance(self):
        """A subclass can expose typed fields through properties."""

        class DetDataSample(BaseDataElement):
            """Sample whose public fields are backed by ``_``-prefixed
            attributes and type-checked via ``set_field``."""

            @property
            def proposals(self):
                return self._proposals

            @proposals.setter
            def proposals(self, value):
                self.set_field(
                    value=value, name='_proposals', dtype=BaseDataElement)

            @proposals.deleter
            def proposals(self):
                del self._proposals

            @property
            def gt_instances(self):
                return self._gt_instances

            @gt_instances.setter
            def gt_instances(self, value):
                self.set_field(
                    value=value, name='_gt_instances', dtype=BaseDataElement)

            @gt_instances.deleter
            def gt_instances(self):
                del self._gt_instances

            @property
            def pred_instances(self):
                return self._pred_instances

            @pred_instances.setter
            def pred_instances(self, value):
                self.set_field(
                    value=value, name='_pred_instances', dtype=BaseDataElement)

            @pred_instances.deleter
            def pred_instances(self):
                del self._pred_instances

        det_sample = DetDataSample()

        # test set
        proposals = BaseDataElement(bboxes=torch.rand((5, 4)))
        det_sample.proposals = proposals
        assert 'proposals' in det_sample

        # test get
        assert det_sample.proposals == proposals

        # test delete
        del det_sample.proposals
        assert 'proposals' not in det_sample

        # test the data whether meet the requirements
        with self.assertRaises(AssertionError):
            det_sample.proposals = torch.rand((5, 4))

    def test_values(self):
        """Value views have one entry per metainfo/data field."""
        # test_metainfo_values
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo, **data)
        assert len(instances.metainfo_values()) == len(metainfo.values())
        # test_all_values
        assert len(instances.all_values()) == len(metainfo.values()) + len(
            data.values())
        # test_values
        assert len(instances.values()) == len(data.values())

    def test_keys(self):
        """Key views have one entry per metainfo/data field."""
        # test_metainfo_keys
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo, **data)
        assert len(instances.metainfo_keys()) == len(metainfo.keys())
        # test_all_keys
        assert len(
            instances.all_keys()) == len(data.keys()) + len(metainfo.keys())
        # test_keys
        assert len(instances.keys()) == len(data.keys())

    def test_items(self):
        """Item views have one entry per metainfo/data field."""
        # test_metainfo_items
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo, **data)
        assert len(dict(instances.metainfo_items())) == len(
            dict(metainfo.items()))
        # test_all_items
        assert len(dict(instances.all_items())) == len(dict(
            metainfo.items())) + len(dict(data.items()))
        # test_items
        assert len(dict(instances.items())) == len(dict(data.items()))

    def test_to_dict(self):
        """``to_dict`` converts the element and its sub-elements."""
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo, **data)
        dict_instances = instances.to_dict()
        # test convert BaseDataElement to dict
        for k in instances.all_keys():
            # all keys in instances should be in dict_instances
            assert k in dict_instances
        assert isinstance(dict_instances, dict)
        # sub data element should also be converted to dict
        assert isinstance(dict_instances['gt_instances'], dict)
        assert isinstance(dict_instances['pred_instances'], dict)

    def test_metainfo(self):
        """The ``metainfo`` property equals the dict given at init."""
        # test metainfo property
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo, **data)
        self.assertDictEqual(instances.metainfo, metainfo)