From 41e1191cbc47f84a0598732baa9ea1c32e315cc7 Mon Sep 17 00:00:00 2001 From: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com> Date: Wed, 16 Feb 2022 23:19:18 +0800 Subject: [PATCH] add unit tests of data abstract interface (#21) * add unit tests of data abstract interface * update * update * update docs of data element * a draft of UT of datasample, to be finished * update datasample test * updata * update * fix comments * fix comments * fix comments Co-authored-by: liukuikun <liukuikun@sensetime.com> --- .../tutorials/abstract_data_interface.md | 11 +- tests/test_data/test_data_element.py | 204 ++++++++++++ tests/test_data/test_data_sample.py | 298 ++++++++++++++++++ 3 files changed, 508 insertions(+), 5 deletions(-) create mode 100644 tests/test_data/test_data_element.py create mode 100644 tests/test_data/test_data_sample.py diff --git a/docs/zh_cn/tutorials/abstract_data_interface.md b/docs/zh_cn/tutorials/abstract_data_interface.md index ce85d7d7..6f2e2905 100644 --- a/docs/zh_cn/tutorials/abstract_data_interface.md +++ b/docs/zh_cn/tutorials/abstract_data_interface.md @@ -113,13 +113,13 @@ gt_instances2 = gt_instances1.new() gt_instances = BaseDataElement() # 设置 gt_instances çš„ meta å—段,img_id å’Œ img_shape 会被作为 metainfo çš„å—段æˆä¸º gt_instances 的属性 gt_instances.set_metainfo(dict(img_id=9, img_shape=(100, 100)) -assert 'img_shape' in gt_instaces.metainfo_keys() +assert 'img_shape' in gt_instances.metainfo_keys() # 'img_shape' 是 gt_instances 的属性 -assert 'img_shape' in gt_instaces +assert 'img_shape' in gt_instances # img_shape ä¸æ˜¯ gt_instances çš„ data å—段 -assert 'img_shape' not in gt_instaces.data_keys() +assert 'img_shape' not in gt_instances.data_keys() # 通过 keys æ¥è®¿é—®æ‰€æœ‰å±žæ€§ -assert 'img_shape' in gt_instaces.keys() +assert 'img_shape' in gt_instances.keys() # è®¿é—®ç±»å±žæ€§ä¸€æ ·è®¿é—® 'img_shape' print(gt_instances.img_shape) @@ -179,7 +179,7 @@ gt_instances.get('bboxes', None) # 6x4 tensor # å±žæ€§çš„åˆ é™¤ del gt_instances.img_shape del gt_instances.bboxes -assert 'img_shape' in gt_instances +assert 'img_shape' not in gt_instances assert 'bboxes' not in gt_instances # æ供了便æ·çš„å±žæ€§åˆ é™¤å’Œè®¿é—®æ“作 pop @@ -191,6 +191,7 @@ gt_instances.pop('bboxes', None) # None 用户å¯ä»¥åƒ torch.Tensor é‚£æ ·å¯¹ `BaseDataElement` çš„ data 进行状æ€è½¬æ¢ï¼Œç›®å‰æ”¯æŒ `cuda`, `cpu`, `to`, `numpy` ç‰æ“作。 å…¶ä¸ï¼Œ`to` 函数拥有和 `torch.Tensor.to()` 相åŒçš„接å£ï¼Œä½¿å¾—用户å¯ä»¥çµæ´»åœ°å°†è¢«å°è£…çš„ tensor 进行状æ€è½¬æ¢ã€‚ +**注æ„:** 这些接å£åªä¼šå¤„ç†ç±»åž‹ä¸º np.array,torch.Tensor,或者数å—çš„åºåˆ—,其他属性的数æ®ï¼ˆå¦‚å—符串)会被跳过处ç†ã€‚ ```python # 将所有 data 转移到 GPU 上 diff --git a/tests/test_data/test_data_element.py b/tests/test_data/test_data_element.py new file mode 100644 index 00000000..4bf8c722 --- /dev/null +++ b/tests/test_data/test_data_element.py @@ -0,0 +1,204 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import random +from unittest import TestCase + +import numpy as np +import pytest +import torch + +from mmengine.data import BaseDataElement + + +class TestBaseDataElement(TestCase): + + def setup_data(self): + metainfo = dict( + img_id=random.randint(0, 100), + img_shape=(random.randint(400, 600), random.randint(400, 600))) + data = dict(bboxes=torch.rand((5, 4)), scores=torch.rand((5, ))) + return metainfo, data + + def check_key_value(self, instances, metainfo=None, data=None): + # check the existence of keys in metainfo, data, and instances + if metainfo: + for k, v in metainfo.items(): + assert k in instances + assert k in instances.keys() + assert k in instances.metainfo_keys() + assert k not in instances.data_keys() + assert instances.get(k) == v + assert getattr(instances, k) == v + if data: + for k, v in data.items(): + assert k in instances + assert k in instances.keys() + assert k not in instances.metainfo_keys() + assert k in instances.data_keys() + assert instances.get(k) == v + assert getattr(instances, k) == v + + def check_data_device(self, instances, device): + assert instances.device == device + for v in instances.data_values(): + if isinstance(v, torch.Tensor): + assert v.device == device + + def check_data_dtype(self, instances, dtype): + for v in instances.data_values(): + if isinstance(v, (torch.Tensor, np.ndarray)): + assert isinstance(v, dtype) + + def test_init(self): + # initialization with no data and metainfo + metainfo, data = self.setup_data() + instances = BaseDataElement() + for k in metainfo: + assert k not in instances + assert instances.get(k, None) is None + for k in data: + assert k not in instances + assert instances.get(k, 'abc') == 'abc' + + # initialization with kwargs + metainfo, data = self.setup_data() + instances = BaseDataElement(metainfo=metainfo, data=data) + self.check_key_value(instances, metainfo, data) + + # initialization with args + metainfo, data = self.setup_data() + instances = BaseDataElement(metainfo, data) + self.check_key_value(instances, metainfo, data) + + # initialization with args + metainfo, data = self.setup_data() + instances = BaseDataElement(metainfo=metainfo) + self.check_key_value(instances, metainfo) + instances = BaseDataElement(data=data) + self.check_key_value(instances, data=data) + + def test_new(self): + metainfo, data = self.setup_data() + instances = BaseDataElement(metainfo=metainfo, data=data) + + # test new() with no arguments + new_instances = instances.new() + assert type(new_instances) == type(instances) + assert id(new_instances.bboxes) != id(instances.bboxes) + assert id(new_instances.bboxes) != id(data['bboxes']) + self.check_key_value(new_instances, metainfo, data) + + # test new() with arguments + metainfo, data = self.setup_data() + new_instances = instances.new(metainfo=metainfo, data=data) + assert type(new_instances) == type(instances) + assert id(new_instances.bboxes) != id(instances.bboxes) + assert id(new_instances.bboxes) != id(data['bboxes']) + self.check_key_value(new_instances, metainfo, data) + + def test_set_metainfo(self): + metainfo, _ = self.setup_data() + instances = BaseDataElement() + instances.set_metainfo(metainfo) + self.check_key_value(instances, metainfo=metainfo) + + # test setting existing keys and new keys + new_metainfo, _ = self.setup_data() + new_metainfo.update(other=123) + instances.set_metainfo(new_metainfo) + self.check_key_value(instances, metainfo=new_metainfo) + + def test_set_data(self): + metainfo, data = self.setup_data() + instances = BaseDataElement() + + instances.bboxes = data['bboxes'] + instances.scores = data['scores'] + self.check_key_value(instances, data=data) + + # a.xx only set data rather than metainfo + instances.img_shape = metainfo['img_shape'] + instances.img_id = metainfo['img_id'] + self.check_key_value(instances, data=metainfo) + + metainfo, data = self.setup_data() + instances = BaseDataElement(metainfo, data) + with self.assertRaises(AssertionError): + instances.img_shape = metainfo['img_shape'] + + def test_delete_modify(self): + metainfo, data = self.setup_data() + instances = BaseDataElement(metainfo, data) + + new_metainfo, new_data = self.setup_data() + instances.bboxes = new_data['bboxes'] + instances.scores = new_data['scores'] + + # a.xx only set data rather than metainfo + instances.set_metainfo(new_metainfo) + self.check_key_value(instances, new_metainfo, new_data) + + assert instances.bboxes != data['bboxes'] + assert instances.scores != data['scores'] + assert instances.img_id != metainfo['img_id'] + assert instances.img_shape != metainfo['img_shape'] + + del instances.bboxes + assert instances.pop('scores', None) == new_data['scores'] + with self.assertRaises(AttributeError): + del instances.scores + + assert 'bboxes' not in instances + assert 'scores' not in instances + assert instances.pop('bboxes', None) is None + assert instances.pop('scores', 'abcdef') == 'abcdef' + + @pytest.mark.skipif( + not torch.cuda.is_available(), reason='GPU is required!') + def test_cuda(self): + metainfo, data = self.setup_data() + instances = BaseDataElement(metainfo, data) + + cuda_instances = instances.cuda() + self.check_data_device(instances, 'cuda:0') + + # here we further test to convert from cuda to cpu + cpu_instances = cuda_instances.cpu() + self.check_data_device(cpu_instances, 'cpu') + del cuda_instances + + cuda_instances = instances.to('cuda:0') + self.check_data_device(cuda_instances, 'cuda:0') + + def test_cpu(self): + metainfo, data = self.setup_data() + instances = BaseDataElement(metainfo, data) + self.check_data_device(instances, 'cpu') + + cpu_instances = instances.cpu() + assert cpu_instances.device == 'cpu' + assert cpu_instances.bboxes.device == 'cpu' + assert cpu_instances.scores.device == 'cpu' + + def test_numpy_tensor(self): + metainfo, data = self.setup_data() + instances = BaseDataElement(metainfo, data) + + np_instances = instances.numpy() + self.check_data_dtype(np_instances, np.ndarray) + + tensor_instances = instances.to_tensor() + self.check_data_dtype(tensor_instances, torch.Tensor) + + def test_repr(self): + metainfo = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3)) + instances = BaseDataElement(metainfo=metainfo) + instances.det_labels = torch.LongTensor([0, 1, 2, 3]) + instances.det_scores = torch.Tensor([0.01, 0.1, 0.2, 0.3]) + assert repr(instances) == ('<BaseDataElement(\n' + ' META INFORMATION\n' + 'img_shape: (800, 1196, 3)\n' + 'pad_shape: (800, 1216, 3)\n' + ' DATA FIELDS\n' + 'shape of det_labels: torch.Size([4])\n' + 'shape of det_scores: torch.Size([4])\n' + ') at 0x7f84acd10f90>') diff --git a/tests/test_data/test_data_sample.py b/tests/test_data/test_data_sample.py new file mode 100644 index 00000000..0d16073b --- /dev/null +++ b/tests/test_data/test_data_sample.py @@ -0,0 +1,298 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import random +from functools import partial +from unittest import TestCase + +import numpy as np +import pytest +import torch + +from mmengine.data import BaseDataElement, BaseDataSample + + +class TestBaseDataSample(TestCase): + + def setup_data(self): + metainfo = dict( + img_id=random.randint(0, 100), + img_shape=(random.randint(400, 600), random.randint(400, 600))) + gt_instances = BaseDataElement( + data=dict(bboxes=torch.rand((5, 4)), labels=torch.rand((5, )))) + pred_instances = BaseDataElement( + data=dict(bboxes=torch.rand((5, 4)), scores=torch.rand((5, )))) + data = dict(gt_instances=gt_instances, pred_instances=pred_instances) + return metainfo, data + + def check_key_value(self, instances, metainfo=None, data=None): + # check the existence of keys in metainfo, data, and instances + if metainfo: + for k, v in metainfo.items(): + assert k in instances + assert k in instances.keys() + assert k in instances.metainfo_keys() + assert k not in instances.data_keys() + assert instances.get(k) == v + assert getattr(instances, k) == v + if data: + for k, v in data.items(): + assert k in instances + assert k in instances.keys() + assert k not in instances.metainfo_keys() + assert k in instances.data_keys() + assert instances.get(k) == v + assert getattr(instances, k) == v + + def check_data_device(self, instances, device): + assert instances.device == device + for v in instances.data_values(): + if isinstance(v, torch.Tensor): + assert v.device == device + elif isinstance(v, BaseDataElement): + self.check_data_device(v, device) + + def check_data_dtype(self, instances, dtype): + for v in instances.data_values(): + if isinstance(v, (torch.Tensor, np.ndarray)): + assert isinstance(v, dtype) + if isinstance(v, BaseDataElement): + self.check_data_dtype(v, dtype) + + def test_init(self): + # initialization with no data and metainfo + metainfo, data = self.setup_data() + instances = BaseDataSample() + for k in metainfo: + assert k not in instances + assert instances.get(k, None) is None + for k in data: + assert k not in instances + assert instances.get(k, 'abc') == 'abc' + + # initialization with kwargs + metainfo, data = self.setup_data() + instances = BaseDataSample(metainfo=metainfo, data=data) + self.check_key_value(instances, metainfo, data) + + # initialization with args + metainfo, data = self.setup_data() + instances = BaseDataSample(metainfo, data) + self.check_key_value(instances, metainfo, data) + + # initialization with args + metainfo, data = self.setup_data() + instances = BaseDataSample(metainfo=metainfo) + self.check_key_value(instances, metainfo) + instances = BaseDataSample(data=data) + self.check_key_value(instances, data=data) + + def test_new(self): + metainfo, data = self.setup_data() + instances = BaseDataSample(metainfo=metainfo, data=data) + + # test new() with no arguments + new_instances = instances.new() + assert type(new_instances) == type(instances) + assert id(new_instances.data) != id(instances.data) + assert id(new_instances.bboxes) != id(data) + self.check_key_value(new_instances, metainfo, data) + + # test new() with arguments + metainfo, data = self.setup_data() + new_instances = instances.new(metainfo=metainfo, data=data) + assert type(new_instances) == type(instances) + assert id(new_instances.data) != id(instances.data) + assert id(new_instances.data) != id(data) + self.check_key_value(new_instances, metainfo, data) + + def test_set_metainfo(self): + metainfo, _ = self.setup_data() + instances = BaseDataSample() + instances.set_metainfo(metainfo) + self.check_key_value(instances, metainfo=metainfo) + + # test setting existing keys and new keys + new_metainfo, _ = self.setup_data() + new_metainfo.update(other=123) + instances.set_metainfo(new_metainfo) + self.check_key_value(instances, metainfo=new_metainfo) + + def test_set_data(self): + metainfo, data = self.setup_data() + instances = BaseDataSample() + + instances.gt_instances = data['gt_instances'] + instances.pred_instances = data['pred_instances'] + self.check_key_value(instances, data=data) + + # a.xx only set data rather than metainfo + instances.img_shape = metainfo['img_shape'] + instances.img_id = metainfo['img_id'] + self.check_key_value(instances, data=metainfo) + + def test_delete_modify(self): + metainfo, data = self.setup_data() + instances = BaseDataSample(metainfo, data) + + new_metainfo, new_data = self.setup_data() + instances.gt_instances = new_data['gt_instances'] + instances.pred_instances = new_data['pred_instances'] + + # a.xx only set data rather than metainfo + instances.set_metainfo(new_metainfo) + self.check_key_value(instances, new_metainfo, new_data) + + assert instances.gt_instances != data['gt_instances'] + assert instances.pred_instances != data['pred_instances'] + assert instances.img_id != metainfo['img_id'] + assert instances.img_shape != metainfo['img_shape'] + + del instances.gt_instances + assert instances.pop('pred_instances', + None) == new_data['pred_instances'] + + # pred_instances has been deleted, + # instances does not have the pred_instances + with self.assertRaises(AttributeError): + del instances.pred_instances + + assert 'gt_instances' not in instances + assert 'pred_instances' not in instances + assert instances.pop('gt_instances', None) is None + assert instances.pop('pred_instances', 'abcdef') == 'abcdef' + + @pytest.mark.skipif( + not torch.cuda.is_available(), reason='GPU is required!') + def test_cuda(self): + metainfo, data = self.setup_data() + instances = BaseDataSample(metainfo, data) + + cuda_instances = instances.cuda() + self.check_data_device(instances, 'cuda:0') + + # here we further test to convert from cuda to cpu + cpu_instances = cuda_instances.cpu() + self.check_data_device(cpu_instances, 'cpu') + del cuda_instances + + cuda_instances = instances.to('cuda:0') + self.check_data_device(cuda_instances, 'cuda:0') + + def test_cpu(self): + metainfo, data = self.setup_data() + instances = BaseDataSample(metainfo, data) + self.check_data_device(instances, 'cpu') + + cpu_instances = instances.cpu() + assert cpu_instances.device == 'cpu' + assert cpu_instances.bboxes.device == 'cpu' + assert cpu_instances.scores.device == 'cpu' + + def test_numpy_tensor(self): + metainfo, data = self.setup_data() + instances = BaseDataSample(metainfo, data) + + np_instances = instances.numpy() + self.check_data_dtype(np_instances, np.ndarray) + + tensor_instances = instances.to_tensor() + self.check_data_dtype(tensor_instances, torch.Tensor) + + def test_repr(self): + metainfo = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3)) + gt_instances = BaseDataElement( + data=dict( + det_labels=torch.LongTensor([0, 1, 2, 3], + det_scores=torch.Tensor( + [0.01, 0.1, 0.2, 0.3])))) + data = dict(gt_instances=gt_instances) + instances = BaseDataSample(metainfo=metainfo, data=data) + assert repr(instances) == ('<BaseDataSample(\n' + ' META INFORMATION\n' + 'img_shape: (800, 1196, 3)\n' + 'pad_shape: (800, 1216, 3)\n' + ' DATA FIELDS\n' + '\tgt_instances: <BaseDataElement(\n' + '\t META INFORMATION\n' + '\timg_shape: (800, 1196, 3)\n' + '\tpad_shape: (800, 1216, 3)\n' + '\t DATA FIELDS\n' + '\tshape of det_labels: torch.Size([4])\n' + '\tshape of det_scores: torch.Size([4])\n' + '\t) at 0x7f84acd10f90>' + ') at 0x7f84acd10f90>') + + def test_set_get_fields(self): + metainfo, data = self.setup_data() + instances = BaseDataSample(metainfo) + for key, value in data.items(): + instances._set_field(value, key, BaseDataElement) + self.check_key_value(instances, data=data) + + # test type check + _, data = self.setup_data() + instances = BaseDataSample() + for key, value in data.items(): + with self.assertRaises(AssertionError): + instances._set_field(value, key, BaseDataSample) + + def test_del_field(self): + metainfo, data = self.setup_data() + instances = BaseDataSample(metainfo) + for key, value in data.items(): + instances._set_field(value, key, BaseDataElement) + instances._del_field('gt_instances') + instances._del_field('pred_instances') + + # gt_instance has been deleted, instances does not have the gt_instance + with self.assertRaises(AttributeError): + instances._del_field('gt_instances') + assert 'gt_instances' not in instances + assert 'pred_instances' not in instances + + def test_inheritance(self): + + class DetDataSample(BaseDataSample): + proposals = property( + fget=partial(BaseDataSample._get_field, name='_proposals'), + fset=partial( + BaseDataSample._set_field, + name='_proposals', + dtype=BaseDataElement), + fdel=partial(BaseDataSample._del_field, name='_proposals'), + doc='Region proposals of an image') + gt_instances = property( + fget=partial(BaseDataSample._get_field, name='_gt_instances'), + fset=partial( + BaseDataSample._set_field, + name='_gt_instances', + dtype=BaseDataElement), + fdel=partial(BaseDataSample._del_field, name='_gt_instances'), + doc='Ground truth instances of an image') + pred_instances = property( + fget=partial( + BaseDataSample._get_field, name='_pred_instances'), + fset=partial( + BaseDataSample._set_field, + name='_pred_instances', + dtype=BaseDataElement), + fdel=partial( + BaseDataSample._del_field, name='_pred_instances'), + doc='Predicted instances of an image') + + det_sample = DetDataSample() + + # test set + proposals = BaseDataElement(data=dict(bboxes=torch.rand((5, 4)))) + det_sample.proposals = proposals + assert 'proposals' in det_sample + + # test get + assert det_sample.proposals == proposals + + # test delete + del det_sample.proposals + assert 'proposals' not in det_sample + + # test the data whether meet the requirements + with self.assertRaises(AssertionError): + det_sample.proposals = torch.rand((5, 4)) -- GitLab