# Copyright (c) OpenMMLab. All rights reserved.
import itertools
import random
from unittest import TestCase

import numpy as np
import pytest
import torch

from mmengine.data import BaseDataElement, InstanceData


class TmpObject:
    """Minimal list-of-lists container that supports ``cat``.

    Used as a custom data field in the tests below: it provides ``__len__``,
    dimension-preserving ``__getitem__`` and a static ``cat``.

    Args:
        tmp (list): A list whose elements are all lists.
    """

    def __init__(self, tmp) -> None:
        assert isinstance(tmp, list)
        for t in tmp:
            assert isinstance(t, list)
        self.tmp = tmp

    def __len__(self):
        return len(self.tmp)

    def __getitem__(self, item):
        """Index the container, always returning another ``TmpObject``.

        Raises:
            IndexError: If an integer index is out of range.
        """
        if isinstance(item, int):
            if item >= len(self) or item < -len(self):  # type:ignore
                raise IndexError(f'Index {item} out of range!')
            # keep the dimension: a one-element slice instead of the element
            item = slice(item, None, len(self))
        return TmpObject(self.tmp[item])

    @staticmethod
    def cat(tmp_objs):
        """Concatenate several ``TmpObject`` instances into one.

        A single-element input is returned as is (no copy).
        """
        assert all(isinstance(results, TmpObject) for results in tmp_objs)
        if len(tmp_objs) == 1:
            return tmp_objs[0]
        tmp_list = list(
            itertools.chain.from_iterable(obj.tmp for obj in tmp_objs))
        return TmpObject(tmp_list)

    def __repr__(self):
        return str(self.tmp)


class TmpObjectWithoutCat:
    """Same container as ``TmpObject`` but deliberately without ``cat``.

    Used to check the error raised for custom fields that cannot be
    concatenated.

    Args:
        tmp (list): A list whose elements are all lists.
    """

    def __init__(self, tmp) -> None:
        assert isinstance(tmp, list)
        for t in tmp:
            assert isinstance(t, list)
        self.tmp = tmp

    def __len__(self):
        return len(self.tmp)

    def __getitem__(self, item):
        """Index the container, returning another ``TmpObjectWithoutCat``.

        Raises:
            IndexError: If an integer index is out of range.
        """
        if isinstance(item, int):
            if item >= len(self) or item < -len(self):  # type:ignore
                raise IndexError(f'Index {item} out of range!')
            # keep the dimension: a one-element slice instead of the element
            item = slice(item, None, len(self))
        # Return the same type so an indexed result does not silently gain
        # a ``cat`` method.
        return TmpObjectWithoutCat(self.tmp[item])

    def __repr__(self):
        return str(self.tmp)


class TestInstanceData(TestCase):
    """Tests for setting, indexing, concatenating and sizing InstanceData."""

    def setup_data(self):
        """Build an ``InstanceData`` holding five instances.

        The fields cover a tensor, a numpy array, a list, a tuple, a str and
        a custom ``TmpObject``, each of length 5.
        """
        metainfo = dict(
            img_id=random.randint(0, 100),
            img_shape=(random.randint(400, 600), random.randint(400, 600)))
        instances_infos = [1] * 5
        bboxes = torch.rand((5, 4))
        labels = np.random.rand(5)
        kps = [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]]
        ids = (1, 2, 3, 4, 5)
        name_ids = '12345'
        polygons = TmpObject(np.arange(25).reshape((5, -1)).tolist())
        instance_data = InstanceData(
            metainfo=metainfo,
            bboxes=bboxes,
            labels=labels,
            polygons=polygons,
            kps=kps,
            ids=ids,
            name_ids=name_ids,
            instances_infos=instances_infos)
        return instance_data

    def test_set_data(self):
        instance_data = self.setup_data()

        # test set '_metainfo_fields' or '_data_fields'
        with self.assertRaises(AttributeError):
            instance_data._metainfo_fields = 1
        with self.assertRaises(AttributeError):
            instance_data._data_fields = 1

        # The data length in InstanceData must be the same
        with self.assertRaises(AssertionError):
            instance_data.keypoints = torch.rand((17, 2))

        instance_data.keypoints = torch.rand((5, 2))
        assert 'keypoints' in instance_data

    def test_getitem(self):
        instance_data = InstanceData()
        # length must be greater than 0
        with self.assertRaises(IndexError):
            instance_data[1]

        instance_data = self.setup_data()
        assert len(instance_data) == 5
        slice_instance_data = instance_data[:2]
        assert len(slice_instance_data) == 2
        slice_instance_data = instance_data[1]
        assert len(slice_instance_data) == 1
        # assert the index should in 0 ~ len(instance_data) -1
        with pytest.raises(IndexError):
            instance_data[5]

        # isinstance(str, slice, int, torch.LongTensor, torch.BoolTensor)
        item = torch.Tensor([1, 2, 3, 4])  # float
        with pytest.raises(AssertionError):
            instance_data[item]

        # when input is a bool tensor, The shape of
        # the input at index 0 should equal to
        # the value length in instance_data_field
        with pytest.raises(AssertionError):
            instance_data[item.bool()]

        # test LongTensor
        long_tensor = torch.randint(5, (2, ))
        long_index_instance_data = instance_data[long_tensor]
        assert len(long_index_instance_data) == len(long_tensor)

        # test bool tensor
        bool_tensor = torch.rand(5) > 0.5
        bool_index_instance_data = instance_data[bool_tensor]
        assert len(bool_index_instance_data) == bool_tensor.sum()
        # `torch.rand` is in [0, 1), so this mask is all False
        bool_tensor = torch.rand(5) > 1
        empty_instance_data = instance_data[bool_tensor]
        assert len(empty_instance_data) == bool_tensor.sum()

        # test list index
        list_index = [1, 2]
        list_index_instance_data = instance_data[list_index]
        assert len(list_index_instance_data) == len(list_index)

        # test list bool
        list_bool = [True, False, True, False, False]
        list_bool_instance_data = instance_data[list_bool]
        assert len(list_bool_instance_data) == 2

        # test numpy
        long_numpy = np.random.randint(5, size=2)
        long_numpy_instance_data = instance_data[long_numpy]
        assert len(long_numpy_instance_data) == len(long_numpy)

        bool_numpy = np.random.rand(5) > 0.5
        bool_numpy_instance_data = instance_data[bool_numpy]
        assert len(bool_numpy_instance_data) == bool_numpy.sum()

        # without cat
        instance_data.polygons = TmpObjectWithoutCat(
            np.arange(25).reshape((5, -1)).tolist())
        bool_numpy = np.random.rand(5) > 0.5
        with pytest.raises(
                ValueError,
                match=('The type of `polygons` is '
                       f'`{type(instance_data.polygons)}`, '
                       'which has no attribute of `cat`, so it does not '
                       'support slice with `bool`')):
            instance_data[bool_numpy]

    def test_cat(self):
        instance_data_1 = self.setup_data()
        instance_data_2 = self.setup_data()
        cat_instance_data = InstanceData.cat(
            [instance_data_1, instance_data_2])
        assert len(cat_instance_data) == 10

        # All inputs must be InstanceData
        instance_data_2 = BaseDataElement(
            bboxes=torch.rand((5, 4)), labels=torch.rand((5, )))
        with self.assertRaises(AssertionError):
            InstanceData.cat([instance_data_1, instance_data_2])

        # Input List length must be greater than 0
        with self.assertRaises(AssertionError):
            InstanceData.cat([])

        # Concatenating with an empty InstanceData (all-False mask) must
        # not raise.
        instance_data_2 = instance_data_1.clone()
        instance_data_2 = instance_data_2[torch.zeros(5) > 0.5]
        InstanceData.cat([instance_data_1, instance_data_2])

        cat_instance_data = InstanceData.cat([instance_data_1])
        assert len(cat_instance_data) == 5

        # test custom data cat
        instance_data_1.polygons = TmpObjectWithoutCat(
            np.arange(25).reshape((5, -1)).tolist())
        instance_data_2 = instance_data_1.clone()
        with pytest.raises(
                ValueError,
                match=('The type of `polygons` is '
                       f'`{type(instance_data_1.polygons)}` '
                       'which has no attribute of `cat`')):
            InstanceData.cat([instance_data_1, instance_data_2])

    def test_len(self):
        instance_data = self.setup_data()
        assert len(instance_data) == 5
        instance_data = InstanceData()
        assert len(instance_data) == 0