From 41e1191cbc47f84a0598732baa9ea1c32e315cc7 Mon Sep 17 00:00:00 2001
From: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com>
Date: Wed, 16 Feb 2022 23:19:18 +0800
Subject: [PATCH] add unit tests of data abstract interface (#21)

* add unit tests of data abstract interface

* update

* update

* update docs of data element

* a draft of UT of datasample, to be finished

* update datasample test

* updata

* update

* fix comments

* fix comments

* fix comments

Co-authored-by: liukuikun <liukuikun@sensetime.com>
---
 .../tutorials/abstract_data_interface.md      |  11 +-
 tests/test_data/test_data_element.py          | 204 ++++++++++++
 tests/test_data/test_data_sample.py           | 298 ++++++++++++++++++
 3 files changed, 508 insertions(+), 5 deletions(-)
 create mode 100644 tests/test_data/test_data_element.py
 create mode 100644 tests/test_data/test_data_sample.py

diff --git a/docs/zh_cn/tutorials/abstract_data_interface.md b/docs/zh_cn/tutorials/abstract_data_interface.md
index ce85d7d7..6f2e2905 100644
--- a/docs/zh_cn/tutorials/abstract_data_interface.md
+++ b/docs/zh_cn/tutorials/abstract_data_interface.md
@@ -113,13 +113,13 @@ gt_instances2 = gt_instances1.new()
 gt_instances = BaseDataElement()
 # 设置 gt_instances 的 meta 字段,img_id 和 img_shape 会被作为 metainfo 的字段成为 gt_instances 的属性
 gt_instances.set_metainfo(dict(img_id=9, img_shape=(100, 100))
-assert 'img_shape' in gt_instaces.metainfo_keys()
+assert 'img_shape' in gt_instances.metainfo_keys()
 # 'img_shape' 是 gt_instances 的属性
-assert 'img_shape' in gt_instaces
+assert 'img_shape' in gt_instances
 # img_shape 不是 gt_instances 的 data 字段
-assert 'img_shape' not in gt_instaces.data_keys()
+assert 'img_shape' not in gt_instances.data_keys()
 # 通过 keys 来访问所有属性
-assert 'img_shape' in gt_instaces.keys()
+assert 'img_shape' in gt_instances.keys()
 # 访问类属性一样访问 'img_shape'
 print(gt_instances.img_shape)
 
@@ -179,7 +179,7 @@ gt_instances.get('bboxes', None)    # 6x4 tensor
 # 属性的删除
 del gt_instances.img_shape
 del gt_instances.bboxes
-assert 'img_shape' in gt_instances
+assert 'img_shape' not in gt_instances
 assert 'bboxes' not in gt_instances
 
 # 提供了便捷的属性删除和访问操作 pop
@@ -191,6 +191,7 @@ gt_instances.pop('bboxes', None)  # None
 
 用户可以像 torch.Tensor 那样对 `BaseDataElement` 的 data 进行状态转换,目前支持 `cuda`, `cpu`, `to`, `numpy` 等操作。
 其中,`to` 函数拥有和 `torch.Tensor.to()` 相同的接口,使得用户可以灵活地将被封装的 tensor 进行状态转换。
+**注意:** 这些接口只会处理类型为 np.array,torch.Tensor,或者数字的序列,其他属性的数据(如字符串)会被跳过处理。
 
 ```python
 # 将所有 data 转移到 GPU 上
diff --git a/tests/test_data/test_data_element.py b/tests/test_data/test_data_element.py
new file mode 100644
index 00000000..4bf8c722
--- /dev/null
+++ b/tests/test_data/test_data_element.py
@@ -0,0 +1,204 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import random
+from unittest import TestCase
+
+import numpy as np
+import pytest
+import torch
+
+from mmengine.data import BaseDataElement
+
+
+class TestBaseDataElement(TestCase):
+
+    def setup_data(self):
+        metainfo = dict(
+            img_id=random.randint(0, 100),
+            img_shape=(random.randint(400, 600), random.randint(400, 600)))
+        data = dict(bboxes=torch.rand((5, 4)), scores=torch.rand((5, )))
+        return metainfo, data
+
+    def check_key_value(self, instances, metainfo=None, data=None):
+        # check the existence of keys in metainfo, data, and instances
+        if metainfo:
+            for k, v in metainfo.items():
+                assert k in instances
+                assert k in instances.keys()
+                assert k in instances.metainfo_keys()
+                assert k not in instances.data_keys()
+                assert instances.get(k) == v
+                assert getattr(instances, k) == v
+        if data:
+            for k, v in data.items():
+                assert k in instances
+                assert k in instances.keys()
+                assert k not in instances.metainfo_keys()
+                assert k in instances.data_keys()
+                assert instances.get(k) == v
+                assert getattr(instances, k) == v
+
+    def check_data_device(self, instances, device):
+        assert instances.device == device
+        for v in instances.data_values():
+            if isinstance(v, torch.Tensor):
+                assert v.device == device
+
+    def check_data_dtype(self, instances, dtype):
+        for v in instances.data_values():
+            if isinstance(v, (torch.Tensor, np.ndarray)):
+                assert isinstance(v, dtype)
+
+    def test_init(self):
+        # initialization with no data and metainfo
+        metainfo, data = self.setup_data()
+        instances = BaseDataElement()
+        for k in metainfo:
+            assert k not in instances
+            assert instances.get(k, None) is None
+        for k in data:
+            assert k not in instances
+            assert instances.get(k, 'abc') == 'abc'
+
+        # initialization with kwargs
+        metainfo, data = self.setup_data()
+        instances = BaseDataElement(metainfo=metainfo, data=data)
+        self.check_key_value(instances, metainfo, data)
+
+        # initialization with args
+        metainfo, data = self.setup_data()
+        instances = BaseDataElement(metainfo, data)
+        self.check_key_value(instances, metainfo, data)
+
+        # initialization with args
+        metainfo, data = self.setup_data()
+        instances = BaseDataElement(metainfo=metainfo)
+        self.check_key_value(instances, metainfo)
+        instances = BaseDataElement(data=data)
+        self.check_key_value(instances, data=data)
+
+    def test_new(self):
+        metainfo, data = self.setup_data()
+        instances = BaseDataElement(metainfo=metainfo, data=data)
+
+        # test new() with no arguments
+        new_instances = instances.new()
+        assert type(new_instances) == type(instances)
+        assert id(new_instances.bboxes) != id(instances.bboxes)
+        assert id(new_instances.bboxes) != id(data['bboxes'])
+        self.check_key_value(new_instances, metainfo, data)
+
+        # test new() with arguments
+        metainfo, data = self.setup_data()
+        new_instances = instances.new(metainfo=metainfo, data=data)
+        assert type(new_instances) == type(instances)
+        assert id(new_instances.bboxes) != id(instances.bboxes)
+        assert id(new_instances.bboxes) != id(data['bboxes'])
+        self.check_key_value(new_instances, metainfo, data)
+
+    def test_set_metainfo(self):
+        metainfo, _ = self.setup_data()
+        instances = BaseDataElement()
+        instances.set_metainfo(metainfo)
+        self.check_key_value(instances, metainfo=metainfo)
+
+        # test setting existing keys and new keys
+        new_metainfo, _ = self.setup_data()
+        new_metainfo.update(other=123)
+        instances.set_metainfo(new_metainfo)
+        self.check_key_value(instances, metainfo=new_metainfo)
+
+    def test_set_data(self):
+        metainfo, data = self.setup_data()
+        instances = BaseDataElement()
+
+        instances.bboxes = data['bboxes']
+        instances.scores = data['scores']
+        self.check_key_value(instances, data=data)
+
+        # a.xx only set data rather than metainfo
+        instances.img_shape = metainfo['img_shape']
+        instances.img_id = metainfo['img_id']
+        self.check_key_value(instances, data=metainfo)
+
+        metainfo, data = self.setup_data()
+        instances = BaseDataElement(metainfo, data)
+        with self.assertRaises(AssertionError):
+            instances.img_shape = metainfo['img_shape']
+
+    def test_delete_modify(self):
+        metainfo, data = self.setup_data()
+        instances = BaseDataElement(metainfo, data)
+
+        new_metainfo, new_data = self.setup_data()
+        instances.bboxes = new_data['bboxes']
+        instances.scores = new_data['scores']
+
+        # a.xx only set data rather than metainfo
+        instances.set_metainfo(new_metainfo)
+        self.check_key_value(instances, new_metainfo, new_data)
+
+        assert instances.bboxes != data['bboxes']
+        assert instances.scores != data['scores']
+        assert instances.img_id != metainfo['img_id']
+        assert instances.img_shape != metainfo['img_shape']
+
+        del instances.bboxes
+        assert instances.pop('scores', None) == new_data['scores']
+        with self.assertRaises(AttributeError):
+            del instances.scores
+
+        assert 'bboxes' not in instances
+        assert 'scores' not in instances
+        assert instances.pop('bboxes', None) is None
+        assert instances.pop('scores', 'abcdef') == 'abcdef'
+
+    @pytest.mark.skipif(
+        not torch.cuda.is_available(), reason='GPU is required!')
+    def test_cuda(self):
+        metainfo, data = self.setup_data()
+        instances = BaseDataElement(metainfo, data)
+
+        cuda_instances = instances.cuda()
+        self.check_data_device(instances, 'cuda:0')
+
+        # here we further test to convert from cuda to cpu
+        cpu_instances = cuda_instances.cpu()
+        self.check_data_device(cpu_instances, 'cpu')
+        del cuda_instances
+
+        cuda_instances = instances.to('cuda:0')
+        self.check_data_device(cuda_instances, 'cuda:0')
+
+    def test_cpu(self):
+        metainfo, data = self.setup_data()
+        instances = BaseDataElement(metainfo, data)
+        self.check_data_device(instances, 'cpu')
+
+        cpu_instances = instances.cpu()
+        assert cpu_instances.device == 'cpu'
+        assert cpu_instances.bboxes.device == 'cpu'
+        assert cpu_instances.scores.device == 'cpu'
+
+    def test_numpy_tensor(self):
+        metainfo, data = self.setup_data()
+        instances = BaseDataElement(metainfo, data)
+
+        np_instances = instances.numpy()
+        self.check_data_dtype(np_instances, np.ndarray)
+
+        tensor_instances = instances.to_tensor()
+        self.check_data_dtype(tensor_instances, torch.Tensor)
+
+    def test_repr(self):
+        metainfo = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
+        instances = BaseDataElement(metainfo=metainfo)
+        instances.det_labels = torch.LongTensor([0, 1, 2, 3])
+        instances.det_scores = torch.Tensor([0.01, 0.1, 0.2, 0.3])
+        assert repr(instances) == ('<BaseDataElement(\n'
+                                   '  META INFORMATION\n'
+                                   'img_shape: (800, 1196, 3)\n'
+                                   'pad_shape: (800, 1216, 3)\n'
+                                   '  DATA FIELDS\n'
+                                   'shape of det_labels: torch.Size([4])\n'
+                                   'shape of det_scores: torch.Size([4])\n'
+                                   ') at 0x7f84acd10f90>')
diff --git a/tests/test_data/test_data_sample.py b/tests/test_data/test_data_sample.py
new file mode 100644
index 00000000..0d16073b
--- /dev/null
+++ b/tests/test_data/test_data_sample.py
@@ -0,0 +1,298 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import random
+from functools import partial
+from unittest import TestCase
+
+import numpy as np
+import pytest
+import torch
+
+from mmengine.data import BaseDataElement, BaseDataSample
+
+
+class TestBaseDataSample(TestCase):
+
+    def setup_data(self):
+        metainfo = dict(
+            img_id=random.randint(0, 100),
+            img_shape=(random.randint(400, 600), random.randint(400, 600)))
+        gt_instances = BaseDataElement(
+            data=dict(bboxes=torch.rand((5, 4)), labels=torch.rand((5, ))))
+        pred_instances = BaseDataElement(
+            data=dict(bboxes=torch.rand((5, 4)), scores=torch.rand((5, ))))
+        data = dict(gt_instances=gt_instances, pred_instances=pred_instances)
+        return metainfo, data
+
+    def check_key_value(self, instances, metainfo=None, data=None):
+        # check the existence of keys in metainfo, data, and instances
+        if metainfo:
+            for k, v in metainfo.items():
+                assert k in instances
+                assert k in instances.keys()
+                assert k in instances.metainfo_keys()
+                assert k not in instances.data_keys()
+                assert instances.get(k) == v
+                assert getattr(instances, k) == v
+        if data:
+            for k, v in data.items():
+                assert k in instances
+                assert k in instances.keys()
+                assert k not in instances.metainfo_keys()
+                assert k in instances.data_keys()
+                assert instances.get(k) == v
+                assert getattr(instances, k) == v
+
+    def check_data_device(self, instances, device):
+        assert instances.device == device
+        for v in instances.data_values():
+            if isinstance(v, torch.Tensor):
+                assert v.device == device
+            elif isinstance(v, BaseDataElement):
+                self.check_data_device(v, device)
+
+    def check_data_dtype(self, instances, dtype):
+        for v in instances.data_values():
+            if isinstance(v, (torch.Tensor, np.ndarray)):
+                assert isinstance(v, dtype)
+            if isinstance(v, BaseDataElement):
+                self.check_data_dtype(v, dtype)
+
+    def test_init(self):
+        # initialization with no data and metainfo
+        metainfo, data = self.setup_data()
+        instances = BaseDataSample()
+        for k in metainfo:
+            assert k not in instances
+            assert instances.get(k, None) is None
+        for k in data:
+            assert k not in instances
+            assert instances.get(k, 'abc') == 'abc'
+
+        # initialization with kwargs
+        metainfo, data = self.setup_data()
+        instances = BaseDataSample(metainfo=metainfo, data=data)
+        self.check_key_value(instances, metainfo, data)
+
+        # initialization with args
+        metainfo, data = self.setup_data()
+        instances = BaseDataSample(metainfo, data)
+        self.check_key_value(instances, metainfo, data)
+
+        # initialization with args
+        metainfo, data = self.setup_data()
+        instances = BaseDataSample(metainfo=metainfo)
+        self.check_key_value(instances, metainfo)
+        instances = BaseDataSample(data=data)
+        self.check_key_value(instances, data=data)
+
+    def test_new(self):
+        metainfo, data = self.setup_data()
+        instances = BaseDataSample(metainfo=metainfo, data=data)
+
+        # test new() with no arguments
+        new_instances = instances.new()
+        assert type(new_instances) == type(instances)
+        assert id(new_instances.data) != id(instances.data)
+        assert id(new_instances.bboxes) != id(data)
+        self.check_key_value(new_instances, metainfo, data)
+
+        # test new() with arguments
+        metainfo, data = self.setup_data()
+        new_instances = instances.new(metainfo=metainfo, data=data)
+        assert type(new_instances) == type(instances)
+        assert id(new_instances.data) != id(instances.data)
+        assert id(new_instances.data) != id(data)
+        self.check_key_value(new_instances, metainfo, data)
+
+    def test_set_metainfo(self):
+        metainfo, _ = self.setup_data()
+        instances = BaseDataSample()
+        instances.set_metainfo(metainfo)
+        self.check_key_value(instances, metainfo=metainfo)
+
+        # test setting existing keys and new keys
+        new_metainfo, _ = self.setup_data()
+        new_metainfo.update(other=123)
+        instances.set_metainfo(new_metainfo)
+        self.check_key_value(instances, metainfo=new_metainfo)
+
+    def test_set_data(self):
+        metainfo, data = self.setup_data()
+        instances = BaseDataSample()
+
+        instances.gt_instances = data['gt_instances']
+        instances.pred_instances = data['pred_instances']
+        self.check_key_value(instances, data=data)
+
+        # a.xx only set data rather than metainfo
+        instances.img_shape = metainfo['img_shape']
+        instances.img_id = metainfo['img_id']
+        self.check_key_value(instances, data=metainfo)
+
+    def test_delete_modify(self):
+        metainfo, data = self.setup_data()
+        instances = BaseDataSample(metainfo, data)
+
+        new_metainfo, new_data = self.setup_data()
+        instances.gt_instances = new_data['gt_instances']
+        instances.pred_instances = new_data['pred_instances']
+
+        # a.xx only set data rather than metainfo
+        instances.set_metainfo(new_metainfo)
+        self.check_key_value(instances, new_metainfo, new_data)
+
+        assert instances.gt_instances != data['gt_instances']
+        assert instances.pred_instances != data['pred_instances']
+        assert instances.img_id != metainfo['img_id']
+        assert instances.img_shape != metainfo['img_shape']
+
+        del instances.gt_instances
+        assert instances.pop('pred_instances',
+                             None) == new_data['pred_instances']
+
+        # pred_instances has been deleted,
+        # instances does not have the pred_instances
+        with self.assertRaises(AttributeError):
+            del instances.pred_instances
+
+        assert 'gt_instances' not in instances
+        assert 'pred_instances' not in instances
+        assert instances.pop('gt_instances', None) is None
+        assert instances.pop('pred_instances', 'abcdef') == 'abcdef'
+
+    @pytest.mark.skipif(
+        not torch.cuda.is_available(), reason='GPU is required!')
+    def test_cuda(self):
+        metainfo, data = self.setup_data()
+        instances = BaseDataSample(metainfo, data)
+
+        cuda_instances = instances.cuda()
+        self.check_data_device(instances, 'cuda:0')
+
+        # here we further test to convert from cuda to cpu
+        cpu_instances = cuda_instances.cpu()
+        self.check_data_device(cpu_instances, 'cpu')
+        del cuda_instances
+
+        cuda_instances = instances.to('cuda:0')
+        self.check_data_device(cuda_instances, 'cuda:0')
+
+    def test_cpu(self):
+        metainfo, data = self.setup_data()
+        instances = BaseDataSample(metainfo, data)
+        self.check_data_device(instances, 'cpu')
+
+        cpu_instances = instances.cpu()
+        assert cpu_instances.device == 'cpu'
+        assert cpu_instances.bboxes.device == 'cpu'
+        assert cpu_instances.scores.device == 'cpu'
+
+    def test_numpy_tensor(self):
+        metainfo, data = self.setup_data()
+        instances = BaseDataSample(metainfo, data)
+
+        np_instances = instances.numpy()
+        self.check_data_dtype(np_instances, np.ndarray)
+
+        tensor_instances = instances.to_tensor()
+        self.check_data_dtype(tensor_instances, torch.Tensor)
+
+    def test_repr(self):
+        metainfo = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
+        gt_instances = BaseDataElement(
+            data=dict(
+                det_labels=torch.LongTensor([0, 1, 2, 3],
+                                            det_scores=torch.Tensor(
+                                                [0.01, 0.1, 0.2, 0.3]))))
+        data = dict(gt_instances=gt_instances)
+        instances = BaseDataSample(metainfo=metainfo, data=data)
+        assert repr(instances) == ('<BaseDataSample(\n'
+                                   '  META INFORMATION\n'
+                                   'img_shape: (800, 1196, 3)\n'
+                                   'pad_shape: (800, 1216, 3)\n'
+                                   '  DATA FIELDS\n'
+                                   '\tgt_instances: <BaseDataElement(\n'
+                                   '\t  META INFORMATION\n'
+                                   '\timg_shape: (800, 1196, 3)\n'
+                                   '\tpad_shape: (800, 1216, 3)\n'
+                                   '\t  DATA FIELDS\n'
+                                   '\tshape of det_labels: torch.Size([4])\n'
+                                   '\tshape of det_scores: torch.Size([4])\n'
+                                   '\t) at 0x7f84acd10f90>'
+                                   ') at 0x7f84acd10f90>')
+
+    def test_set_get_fields(self):
+        metainfo, data = self.setup_data()
+        instances = BaseDataSample(metainfo)
+        for key, value in data.items():
+            instances._set_field(value, key, BaseDataElement)
+        self.check_key_value(instances, data=data)
+
+        # test type check
+        _, data = self.setup_data()
+        instances = BaseDataSample()
+        for key, value in data.items():
+            with self.assertRaises(AssertionError):
+                instances._set_field(value, key, BaseDataSample)
+
+    def test_del_field(self):
+        metainfo, data = self.setup_data()
+        instances = BaseDataSample(metainfo)
+        for key, value in data.items():
+            instances._set_field(value, key, BaseDataElement)
+        instances._del_field('gt_instances')
+        instances._del_field('pred_instances')
+
+        # gt_instance has been deleted, instances does not have the gt_instance
+        with self.assertRaises(AttributeError):
+            instances._del_field('gt_instances')
+        assert 'gt_instances' not in instances
+        assert 'pred_instances' not in instances
+
+    def test_inheritance(self):
+
+        class DetDataSample(BaseDataSample):
+            proposals = property(
+                fget=partial(BaseDataSample._get_field, name='_proposals'),
+                fset=partial(
+                    BaseDataSample._set_field,
+                    name='_proposals',
+                    dtype=BaseDataElement),
+                fdel=partial(BaseDataSample._del_field, name='_proposals'),
+                doc='Region proposals of an image')
+            gt_instances = property(
+                fget=partial(BaseDataSample._get_field, name='_gt_instances'),
+                fset=partial(
+                    BaseDataSample._set_field,
+                    name='_gt_instances',
+                    dtype=BaseDataElement),
+                fdel=partial(BaseDataSample._del_field, name='_gt_instances'),
+                doc='Ground truth instances of an image')
+            pred_instances = property(
+                fget=partial(
+                    BaseDataSample._get_field, name='_pred_instances'),
+                fset=partial(
+                    BaseDataSample._set_field,
+                    name='_pred_instances',
+                    dtype=BaseDataElement),
+                fdel=partial(
+                    BaseDataSample._del_field, name='_pred_instances'),
+                doc='Predicted instances of an image')
+
+        det_sample = DetDataSample()
+
+        # test set
+        proposals = BaseDataElement(data=dict(bboxes=torch.rand((5, 4))))
+        det_sample.proposals = proposals
+        assert 'proposals' in det_sample
+
+        # test get
+        assert det_sample.proposals == proposals
+
+        # test delete
+        del det_sample.proposals
+        assert 'proposals' not in det_sample
+
+        # test the data whether meet the requirements
+        with self.assertRaises(AssertionError):
+            det_sample.proposals = torch.rand((5, 4))
-- 
GitLab