Skip to content
Snippets Groups Projects
Unverified Commit 5c5c03e6 authored by liukuikun's avatar liukuikun Committed by GitHub
Browse files

[Enhance] cat empty instancedata, support torch.bool for more type (#209)

* refactor instancedata

* fix docs

* fix comment
parent 16058fdb
No related branches found
No related tags found
No related merge requests found
...@@ -7,10 +7,11 @@ import torch ...@@ -7,10 +7,11 @@ import torch
class BaseDataElement: class BaseDataElement:
"""A base data structure interface of OpenMMlab. """A base data interface that supports Tensor-like and dict-like
operations.
Data elements refer to predicted results or ground truth labels on a A typical data elements refer to predicted results or ground truth labels
task, such as predicted bboxes, instance masks, semantic on a task, such as predicted bboxes, instance masks, semantic
segmentation masks, etc. Because groundtruth labels and predicted results segmentation masks, etc. Because groundtruth labels and predicted results
often have similar properties (for example, the predicted bboxes and the often have similar properties (for example, the predicted bboxes and the
groundtruth bboxes), MMEngine uses the same abstract data interface to groundtruth bboxes), MMEngine uses the same abstract data interface to
...@@ -23,7 +24,23 @@ class BaseDataElement: ...@@ -23,7 +24,23 @@ class BaseDataElement:
``BaseDataElement``, and implement ``InstanceData``, ``PixelData``, and ``BaseDataElement``, and implement ``InstanceData``, ``PixelData``, and
``LabelData`` inheriting from ``BaseDataElement`` to represent different ``LabelData`` inheriting from ``BaseDataElement`` to represent different
types of ground truth labels or predictions. types of ground truth labels or predictions.
They are used as interfaces between different components.
Another common data element is sample data. A sample data consists of input
data (such as an image) and its annotations and predictions. In general,
an image can have multiple types of annotations and/or predictions at the
same time (for example, both pixel-level semantic segmentation annotations
and instance-level detection bboxes annotations). All labels and
predictions of a training sample are often passed between Dataset, Model,
Visualizer, and Evaluator components. In order to simplify the interface
between components, we can treat them as a large data element and
encapsulate them. Such data elements are generally called XXDataSample in
OpenMMLab. Therefore, similar to `nn.Module`, the `BaseDataElement`
allows `BaseDataElement` as its attribute. Such a class generally
encapsulates all the data of a sample in the algorithm library, and its
attributes generally are various types of data elements. For example,
MMDetection uses the BaseDataElement to encapsulate all the annotation and
prediction data elements of a sample in the algorithm library.
The attributes in ``BaseDataElement`` are divided into two parts, The attributes in ``BaseDataElement`` are divided into two parts,
the ``metainfo`` and the ``data`` respectively. the ``metainfo`` and the ``data`` respectively.
...@@ -70,8 +87,8 @@ class BaseDataElement: ...@@ -70,8 +87,8 @@ class BaseDataElement:
>>> # new >>> # new
>>> gt_instances1 = gt_instance.new( >>> gt_instances1 = gt_instance.new(
... metainfo=dict(img_id=1, img_shape=(640, 640)), ... metainfo=dict(img_id=1, img_shape=(640, 640)),
... data=dict(bboxes=torch.rand((5, 4)), ... bboxes=torch.rand((5, 4)),
... scores=torch.rand((5,)))) ... scores=torch.rand((5,)))
>>> gt_instances2 = gt_instances1.new() >>> gt_instances2 = gt_instances1.new()
>>> # add and process property >>> # add and process property
...@@ -241,8 +258,9 @@ class BaseDataElement: ...@@ -241,8 +258,9 @@ class BaseDataElement:
self.set_data(dict(instance.items())) self.set_data(dict(instance.items()))
def new(self, def new(self,
metainfo: dict = None, *,
data: dict = None) -> 'BaseDataElement': metainfo: Optional[dict] = None,
**kwargs) -> 'BaseDataElement':
"""Return a new data element with same type. If ``metainfo`` and """Return a new data element with same type. If ``metainfo`` and
``data`` are None, the new data element will have same metainfo and ``data`` are None, the new data element will have same metainfo and
data. If metainfo or data is not None, the new result will overwrite it data. If metainfo or data is not None, the new result will overwrite it
...@@ -252,8 +270,9 @@ class BaseDataElement: ...@@ -252,8 +270,9 @@ class BaseDataElement:
metainfo (dict, optional): A dict contains the meta information metainfo (dict, optional): A dict contains the meta information
of image, such as ``img_shape``, ``scale_factor``, etc. of image, such as ``img_shape``, ``scale_factor``, etc.
Defaults to None. Defaults to None.
data (dict, optional): A dict contains annotations of image or kwargs (dict): A dict contains annotations of image or
model predictions. Defaults to None. model predictions.
Returns: Returns:
BaseDataElement: a new data element with same type. BaseDataElement: a new data element with same type.
""" """
...@@ -263,8 +282,8 @@ class BaseDataElement: ...@@ -263,8 +282,8 @@ class BaseDataElement:
new_data.set_metainfo(metainfo) new_data.set_metainfo(metainfo)
else: else:
new_data.set_metainfo(dict(self.metainfo_items())) new_data.set_metainfo(dict(self.metainfo_items()))
if data is not None: if kwargs:
new_data.set_data(data) new_data.set_data(kwargs)
else: else:
new_data.set_data(dict(self.items())) new_data.set_data(dict(self.items()))
return new_data return new_data
...@@ -388,7 +407,6 @@ class BaseDataElement: ...@@ -388,7 +407,6 @@ class BaseDataElement:
self._data_fields.remove(item) self._data_fields.remove(item)
# dict-like methods # dict-like methods
__setitem__ = __setattr__
__delitem__ = __delattr__ __delitem__ = __delattr__
def get(self, key, default=None) -> Any: def get(self, key, default=None) -> Any:
...@@ -519,6 +537,7 @@ class BaseDataElement: ...@@ -519,6 +537,7 @@ class BaseDataElement:
} }
def __repr__(self) -> str: def __repr__(self) -> str:
"""represent the object."""
def _addindent(s_: str, num_spaces: int) -> str: def _addindent(s_: str, num_spaces: int) -> str:
"""This func is modified from `pytorch` https://github.com/pytorch/ """This func is modified from `pytorch` https://github.com/pytorch/
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import itertools import itertools
from collections.abc import Sized
from typing import List, Union from typing import List, Union
import numpy as np import numpy as np
...@@ -7,8 +8,9 @@ import torch ...@@ -7,8 +8,9 @@ import torch
from .base_data_element import BaseDataElement from .base_data_element import BaseDataElement
IndexType = Union[str, slice, int, torch.LongTensor, torch.cuda.LongTensor, IndexType = Union[str, slice, int, list, torch.LongTensor,
torch.BoolTensor, torch.cuda.BoolTensor, np.long, np.bool] torch.cuda.LongTensor, torch.BoolTensor,
torch.cuda.BoolTensor, np.ndarray]
# Modified from # Modified from
...@@ -19,8 +21,37 @@ class InstanceData(BaseDataElement): ...@@ -19,8 +21,37 @@ class InstanceData(BaseDataElement):
Subclass of :class:`BaseDataElement`. All value in `data_fields` Subclass of :class:`BaseDataElement`. All value in `data_fields`
should have the same length. This design refer to should have the same length. This design refer to
https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py # noqa E501 https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py # noqa E501
InstanceData also supports extra functions: ``index``, ``slice`` and ``cat`` for data fields. The type of a value
in a data field can be a base data structure such as `torch.Tensor`, `numpy.ndarray`, `list`, `str`, `tuple`,
and can be a customized data structure that has ``__len__``, ``__getitem__`` and ``cat`` attributes.
Examples: Examples:
>>> # custom data structure
>>> class TmpObject:
... def __init__(self, tmp) -> None:
... assert isinstance(tmp, list)
... self.tmp = tmp
... def __len__(self):
... return len(self.tmp)
... def __getitem__(self, item):
... if type(item) == int:
... if item >= len(self) or item < -len(self): # type:ignore
... raise IndexError(f'Index {item} out of range!')
... else:
... # keep the dimension
... item = slice(item, None, len(self))
... return TmpObject(self.tmp[item])
... @staticmethod
... def cat(tmp_objs):
... assert all(isinstance(results, TmpObject) for results in tmp_objs)
... if len(tmp_objs) == 1:
... return tmp_objs[0]
... tmp_list = [tmp_obj.tmp for tmp_obj in tmp_objs]
... tmp_list = list(itertools.chain(*tmp_list))
... new_data = TmpObject(tmp_list)
... return new_data
... def __repr__(self):
... return str(self.tmp)
>>> from mmengine.data import InstanceData >>> from mmengine.data import InstanceData
>>> import numpy as np >>> import numpy as np
>>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3)) >>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
...@@ -30,44 +61,69 @@ class InstanceData(BaseDataElement): ...@@ -30,44 +61,69 @@ class InstanceData(BaseDataElement):
>>> instance_data.det_labels = torch.LongTensor([2, 3]) >>> instance_data.det_labels = torch.LongTensor([2, 3])
>>> instance_data["det_scores"] = torch.Tensor([0.8, 0.7]) >>> instance_data["det_scores"] = torch.Tensor([0.8, 0.7])
>>> instance_data.bboxes = torch.rand((2, 4)) >>> instance_data.bboxes = torch.rand((2, 4))
>>> instance_data.polygons = TmpObject([[1, 2, 3, 4], [5, 6, 7, 8]])
>>> len(instance_data) >>> len(instance_data)
4 2
>>> print(instance_data) >>> print(instance_data)
<InstanceData( <InstanceData(
META INFORMATION META INFORMATION
pad_shape: (800, 1196, 3) pad_shape: (800, 1196, 3)
img_shape: (800, 1216, 3) img_shape: (800, 1216, 3)
DATA FIELDS DATA FIELDS
det_labels: tensor([2, 3]) det_labels: tensor([2, 3])
det_scores: tensor([0.8, 0.7000]) det_scores: tensor([0.8, 0.7000])
bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188], bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188],
[0.8101, 0.3105, 0.5123, 0.6263]]) [0.8101, 0.3105, 0.5123, 0.6263]])
polygons: [[1, 2, 3, 4], [5, 6, 7, 8]]
) at 0x7fb492de6280> ) at 0x7fb492de6280>
>>> sorted_results = instance_data[instance_data.det_scores.sort().indices] >>> sorted_results = instance_data[instance_data.det_scores.sort().indices]
>>> sorted_results.det_scores >>> sorted_results.det_scores
tensor([0.7000, 0.8000]) tensor([0.7000, 0.8000])
>>> print(instance_data[instance_data.det_scores > 0.75]) >>> print(instance_data[instance_data.det_scores > 0.75])
<InstanceData( <InstanceData(
META INFORMATION META INFORMATION
pad_shape: (800, 1216, 3) pad_shape: (800, 1216, 3)
img_shape: (800, 1196, 3) img_shape: (800, 1196, 3)
DATA FIELDS DATA FIELDS
det_labels: tensor([0]) det_labels: tensor([2])
bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188]]) masks: [[11, 21, 31, 41]]
det_scores: tensor([0.8000]) det_scores: tensor([0.8000])
) at 0x7fb5cf6e2790> bboxes: tensor([[0.9308, 0.4000, 0.6077, 0.5554]])
>>> instance_data[instance_data.det_scores > 0.75].det_labels polygons: [[1, 2, 3, 4]]
tensor([0]) ) at 0x7f64ecf0ec40>
>>> instance_data[instance_data.det_scores > 0.75].det_scores >>> print(instance_data[instance_data.det_scores > 1])
tensor([0.8000]) <InstanceData(
META INFORMATION
pad_shape: (800, 1216, 3)
img_shape: (800, 1196, 3)
DATA FIELDS
det_labels: tensor([], dtype=torch.int64)
masks: []
det_scores: tensor([])
bboxes: tensor([], size=(0, 4))
polygons: [[]]
) at 0x7f660a6a7f70>
>>> print(instance_data.cat([instance_data, instance_data]))
<InstanceData(
META INFORMATION
img_shape: (800, 1196, 3)
pad_shape: (800, 1216, 3)
DATA FIELDS
det_labels: tensor([2, 3, 2, 3])
bboxes: tensor([[0.7404, 0.6332, 0.1684, 0.9961],
[0.2837, 0.8112, 0.5416, 0.2810],
[0.7404, 0.6332, 0.1684, 0.9961],
[0.2837, 0.8112, 0.5416, 0.2810]])
data:
polygons: [[1, 2, 3, 4], [5, 6, 7, 8],
[1, 2, 3, 4], [5, 6, 7, 8]]
det_scores: tensor([0.8000, 0.7000, 0.8000, 0.7000])
masks: [[11, 21, 31, 41], [51, 61, 71, 81],
[11, 21, 31, 41], [51, 61, 71, 81]]
) at 0x7f203542feb0>
""" """
def __setattr__(self, name: str, value: Union[torch.Tensor, np.ndarray, def __setattr__(self, name: str, value: Sized):
list]):
"""setattr is only used to set data. """setattr is only used to set data.
the value must have the attribute of `__len__` and have the same length the value must have the attribute of `__len__` and have the same length
...@@ -82,9 +138,8 @@ class InstanceData(BaseDataElement): ...@@ -82,9 +138,8 @@ class InstanceData(BaseDataElement):
f'private attribute, which is immutable. ') f'private attribute, which is immutable. ')
else: else:
assert isinstance(value, (torch.Tensor, np.ndarray, list)), \ assert isinstance(value,
f'Can set {type(value)}, only support' \ Sized), 'value must contain `_len__` attribute'
f' {(torch.Tensor, np.ndarray, list)}'
if len(self) > 0: if len(self) > 0:
assert len(value) == len(self), f'the length of ' \ assert len(value) == len(self), f'the length of ' \
...@@ -95,6 +150,8 @@ class InstanceData(BaseDataElement): ...@@ -95,6 +150,8 @@ class InstanceData(BaseDataElement):
f'{len(self)} ' f'{len(self)} '
super().__setattr__(name, value) super().__setattr__(name, value)
__setitem__ = __setattr__
def __getitem__(self, item: IndexType) -> 'InstanceData': def __getitem__(self, item: IndexType) -> 'InstanceData':
""" """
Args: Args:
...@@ -105,11 +162,13 @@ class InstanceData(BaseDataElement): ...@@ -105,11 +162,13 @@ class InstanceData(BaseDataElement):
Returns: Returns:
obj:`InstanceData`: Corresponding values. obj:`InstanceData`: Corresponding values.
""" """
assert len(self) > 0, ' This is a empty instance' if isinstance(item, list):
item = np.array(item)
if isinstance(item, np.ndarray):
item = torch.from_numpy(item)
assert isinstance( assert isinstance(
item, (str, slice, int, torch.LongTensor, torch.cuda.LongTensor, item, (str, slice, int, torch.LongTensor, torch.cuda.LongTensor,
torch.BoolTensor, torch.cuda.BoolTensor, np.bool, np.long)) torch.BoolTensor, torch.cuda.BoolTensor))
if isinstance(item, str): if isinstance(item, str):
return getattr(self, item) return getattr(self, item)
...@@ -121,7 +180,7 @@ class InstanceData(BaseDataElement): ...@@ -121,7 +180,7 @@ class InstanceData(BaseDataElement):
# keep the dimension # keep the dimension
item = slice(item, None, len(self)) item = slice(item, None, len(self))
new_data = self.new(data={}) new_data = self.__class__(metainfo=self.metainfo)
if isinstance(item, torch.Tensor): if isinstance(item, torch.Tensor):
assert item.dim() == 1, 'Only support to get the' \ assert item.dim() == 1, 'Only support to get the' \
' values along the first dimension.' ' values along the first dimension.'
...@@ -140,17 +199,36 @@ class InstanceData(BaseDataElement): ...@@ -140,17 +199,36 @@ class InstanceData(BaseDataElement):
new_data[k] = v[item] new_data[k] = v[item]
elif isinstance(v, np.ndarray): elif isinstance(v, np.ndarray):
new_data[k] = v[item.cpu().numpy()] new_data[k] = v[item.cpu().numpy()]
elif isinstance(v, list): elif isinstance(
r_list = [] v, (str, list, tuple)) or (hasattr(v, '__getitem__')
and hasattr(v, 'cat')):
# convert to indexes from boolTensor # convert to indexes from boolTensor
if isinstance(item, if isinstance(item,
(torch.BoolTensor, torch.cuda.BoolTensor)): (torch.BoolTensor, torch.cuda.BoolTensor)):
indexes = torch.nonzero(item).view(-1) indexes = torch.nonzero(item).view(
-1).cpu().numpy().tolist()
else:
indexes = item.cpu().numpy().tolist()
slice_list = []
if indexes:
for index in indexes:
slice_list.append(slice(index, None, len(v)))
else: else:
indexes = item slice_list.append(slice(None, 0, None))
for index in indexes: r_list = [v[s] for s in slice_list]
r_list.append(v[index]) if isinstance(v, (str, list, tuple)):
new_data[k] = r_list new_value = r_list[0]
for r in r_list[1:]:
new_value = new_value + r
else:
new_value = v.cat(r_list)
new_data[k] = new_value
else:
raise ValueError(
f'The type of `{k}` is `{type(v)}`, which has no '
'attribute of `cat`, so it does not '
f'support slice with `bool`')
else: else:
# item is a slice # item is a slice
for k, v in self.items(): for k, v in self.items():
...@@ -191,24 +269,30 @@ class InstanceData(BaseDataElement): ...@@ -191,24 +269,30 @@ class InstanceData(BaseDataElement):
'elements in `instances_list` ' \ 'elements in `instances_list` ' \
'have the exact same key ' 'have the exact same key '
new_data = instances_list[0].new(data={}) new_data = instances_list[0].__class__(
metainfo=instances_list[0].metainfo)
for k in instances_list[0].keys(): for k in instances_list[0].keys():
values = [results[k] for results in instances_list] values = [results[k] for results in instances_list]
v0 = values[0] v0 = values[0]
if isinstance(v0, torch.Tensor): if isinstance(v0, torch.Tensor):
values = torch.cat(values, dim=0) new_values = torch.cat(values, dim=0)
elif isinstance(v0, np.ndarray): elif isinstance(v0, np.ndarray):
values = np.concatenate(values, axis=0) new_values = np.concatenate(values, axis=0)
elif isinstance(v0, list): elif isinstance(v0, (str, list, tuple)):
values = list(itertools.chain(*values)) new_values = v0[:]
for v in values[1:]:
new_values += v
elif hasattr(v0, 'cat'):
new_values = v0.cat(values)
else: else:
raise ValueError( raise ValueError(
f'Can not concat the {k} which is a {type(v0)}') f'The type of `{k}` is `{type(v0)}` which has no '
new_data[k] = values 'attribute of `cat`')
new_data[k] = new_values
return new_data # type:ignore return new_data # type:ignore
def __len__(self) -> int: def __len__(self) -> int:
"""The length of instance data.""" """int: the length of InstanceData"""
if len(self._data_fields) > 0: if len(self._data_fields) > 0:
return len(self.values()[0]) return len(self.values()[0])
else: else:
......
...@@ -112,7 +112,7 @@ class TestBaseDataElement(TestCase): ...@@ -112,7 +112,7 @@ class TestBaseDataElement(TestCase):
# test new() with arguments # test new() with arguments
metainfo, data = self.setup_data() metainfo, data = self.setup_data()
new_instances = instances.new(metainfo=metainfo, data=data) new_instances = instances.new(metainfo=metainfo, **data)
assert type(new_instances) == type(instances) assert type(new_instances) == type(instances)
assert id(new_instances.gt_instances) != id(instances.gt_instances) assert id(new_instances.gt_instances) != id(instances.gt_instances)
_, new_data = self.setup_data() _, new_data = self.setup_data()
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import itertools
import random import random
from unittest import TestCase from unittest import TestCase
...@@ -9,6 +10,66 @@ import torch ...@@ -9,6 +10,66 @@ import torch
from mmengine.data import BaseDataElement, InstanceData from mmengine.data import BaseDataElement, InstanceData
class TmpObject:
    """List-backed custom container used as a data field in the tests.

    It implements ``__len__``, ``__getitem__`` and a static ``cat`` so that
    it can be stored, indexed and concatenated like the built-in field
    types. Indexing always returns another container (never a bare row),
    so the "instance" dimension is preserved.

    Args:
        tmp (list[list]): The rows to wrap. Must be a list whose elements
            are all lists.
    """

    def __init__(self, tmp) -> None:
        assert isinstance(tmp, list)
        # Iterating an empty list is a no-op, so no explicit length guard
        # is needed before validating the rows.
        for t in tmp:
            assert isinstance(t, list)
        self.tmp = tmp

    def __len__(self):
        """Return the number of rows."""
        return len(self.tmp)

    def __getitem__(self, item):
        """Return a new container holding the selected row(s).

        Args:
            item (int | slice): Row index or slice. Negative indices are
                supported.

        Returns:
            TmpObject: A container with the selected rows.

        Raises:
            IndexError: If an integer index is out of range.
        """
        # ``isinstance`` (rather than ``type(item) == int``) also covers
        # ``bool`` and other ``int`` subclasses, which would otherwise fall
        # through to raw list indexing and yield a bare row.
        if isinstance(item, int):
            if item >= len(self) or item < -len(self):
                raise IndexError(f'Index {item} out of range!')
            # keep the dimension: a step of ``len(self)`` selects exactly
            # one row starting at ``item``.
            item = slice(item, None, len(self))
        return type(self)(self.tmp[item])

    @staticmethod
    def cat(tmp_objs):
        """Concatenate the rows of several containers, in order.

        Args:
            tmp_objs (list[TmpObject]): Containers to concatenate.

        Returns:
            TmpObject: A new container with all rows. A new object is
            returned even for a single input, so the result never aliases
            an argument.
        """
        assert all(isinstance(results, TmpObject) for results in tmp_objs)
        merged = list(
            itertools.chain.from_iterable(obj.tmp for obj in tmp_objs))
        return TmpObject(merged)

    def __repr__(self):
        """Return the string form of the wrapped rows."""
        return str(self.tmp)
class TmpObjectWithoutCat:
    """List-backed custom container that deliberately has no ``cat``.

    The tests assign it to a data field to check that operations needing
    ``cat`` raise ``ValueError`` for such a field.

    Args:
        tmp (list[list]): The rows to wrap. Must be a list whose elements
            are all lists.
    """

    def __init__(self, tmp) -> None:
        assert isinstance(tmp, list)
        # Iterating an empty list is a no-op, so no explicit length guard
        # is needed before validating the rows.
        for t in tmp:
            assert isinstance(t, list)
        self.tmp = tmp

    def __len__(self):
        """Return the number of rows."""
        return len(self.tmp)

    def __getitem__(self, item):
        """Return a new container holding the selected row(s).

        Args:
            item (int | slice): Row index or slice. Negative indices are
                supported.

        Returns:
            TmpObjectWithoutCat: A container with the selected rows.

        Raises:
            IndexError: If an integer index is out of range.
        """
        # ``isinstance`` (rather than ``type(item) == int``) also covers
        # ``bool`` and other ``int`` subclasses.
        if isinstance(item, int):
            if item >= len(self) or item < -len(self):
                raise IndexError(f'Index {item} out of range!')
            # keep the dimension: a step of ``len(self)`` selects exactly
            # one row starting at ``item``.
            item = slice(item, None, len(self))
        # Return the same type: the result must also lack ``cat``,
        # otherwise indexing would quietly turn the field into a
        # concatenable object.
        return type(self)(self.tmp[item])

    def __repr__(self):
        """Return the string form of the wrapped rows."""
        return str(self.tmp)
class TestInstanceData(TestCase): class TestInstanceData(TestCase):
def setup_data(self): def setup_data(self):
...@@ -18,10 +79,18 @@ class TestInstanceData(TestCase): ...@@ -18,10 +79,18 @@ class TestInstanceData(TestCase):
instances_infos = [1] * 5 instances_infos = [1] * 5
bboxes = torch.rand((5, 4)) bboxes = torch.rand((5, 4))
labels = np.random.rand(5) labels = np.random.rand(5)
kps = [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]]
ids = (1, 2, 3, 4, 5)
name_ids = '12345'
polygons = TmpObject(np.arange(25).reshape((5, -1)).tolist())
instance_data = InstanceData( instance_data = InstanceData(
metainfo=metainfo, metainfo=metainfo,
bboxes=bboxes, bboxes=bboxes,
labels=labels, labels=labels,
polygons=polygons,
kps=kps,
ids=ids,
name_ids=name_ids,
instances_infos=instances_infos) instances_infos=instances_infos)
return instance_data return instance_data
...@@ -34,10 +103,6 @@ class TestInstanceData(TestCase): ...@@ -34,10 +103,6 @@ class TestInstanceData(TestCase):
with self.assertRaises(AttributeError): with self.assertRaises(AttributeError):
instance_data._data_fields = 1 instance_data._data_fields = 1
# value only supports (torch.Tensor, np.ndarray, list)
with self.assertRaises(AssertionError):
instance_data.v = 'value'
# The data length in InstanceData must be the same # The data length in InstanceData must be the same
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
instance_data.keypoints = torch.rand((17, 2)) instance_data.keypoints = torch.rand((17, 2))
...@@ -48,14 +113,15 @@ class TestInstanceData(TestCase): ...@@ -48,14 +113,15 @@ class TestInstanceData(TestCase):
def test_getitem(self): def test_getitem(self):
instance_data = InstanceData() instance_data = InstanceData()
# length must be greater than 0 # length must be greater than 0
with self.assertRaises(AssertionError): with self.assertRaises(IndexError):
instance_data[1] instance_data[1]
instance_data = self.setup_data() instance_data = self.setup_data()
assert len(instance_data) == 5 assert len(instance_data) == 5
slice_instance_data = instance_data[:2] slice_instance_data = instance_data[:2]
assert len(slice_instance_data) == 2 assert len(slice_instance_data) == 2
slice_instance_data = instance_data[1]
assert len(slice_instance_data) == 1
# assert the index should in 0 ~ len(instance_data) -1 # assert the index should in 0 ~ len(instance_data) -1
with pytest.raises(IndexError): with pytest.raises(IndexError):
instance_data[5] instance_data[5]
...@@ -80,6 +146,40 @@ class TestInstanceData(TestCase): ...@@ -80,6 +146,40 @@ class TestInstanceData(TestCase):
bool_tensor = torch.rand(5) > 0.5 bool_tensor = torch.rand(5) > 0.5
bool_index_instance_data = instance_data[bool_tensor] bool_index_instance_data = instance_data[bool_tensor]
assert len(bool_index_instance_data) == bool_tensor.sum() assert len(bool_index_instance_data) == bool_tensor.sum()
bool_tensor = torch.rand(5) > 1
empty_instance_data = instance_data[bool_tensor]
assert len(empty_instance_data) == bool_tensor.sum()
# test list index
list_index = [1, 2]
list_index_instance_data = instance_data[list_index]
assert len(list_index_instance_data) == len(list_index)
# test list bool
list_bool = [True, False, True, False, False]
list_bool_instance_data = instance_data[list_bool]
assert len(list_bool_instance_data) == 2
# test numpy
long_numpy = np.random.randint(5, size=2)
long_numpy_instance_data = instance_data[long_numpy]
assert len(long_numpy_instance_data) == len(long_numpy)
bool_numpy = np.random.rand(5) > 0.5
bool_numpy_instance_data = instance_data[bool_numpy]
assert len(bool_numpy_instance_data) == bool_numpy.sum()
# without cat
instance_data.polygons = TmpObjectWithoutCat(
np.arange(25).reshape((5, -1)).tolist())
bool_numpy = np.random.rand(5) > 0.5
with pytest.raises(
ValueError,
match=('The type of `polygons` is '
f'`{type(instance_data.polygons)}`, '
'which has no attribute of `cat`, so it does not '
f'support slice with `bool`')):
bool_numpy_instance_data = instance_data[bool_numpy]
def test_cat(self): def test_cat(self):
instance_data_1 = self.setup_data() instance_data_1 = self.setup_data()
...@@ -97,6 +197,24 @@ class TestInstanceData(TestCase): ...@@ -97,6 +197,24 @@ class TestInstanceData(TestCase):
# Input List length must be greater than 0 # Input List length must be greater than 0
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
InstanceData.cat([]) InstanceData.cat([])
instance_data_2 = instance_data_1.clone()
instance_data_2 = instance_data_2[torch.zeros(5) > 0.5]
cat_instance_data = InstanceData.cat(
[instance_data_1, instance_data_2])
cat_instance_data = InstanceData.cat([instance_data_1])
assert len(cat_instance_data) == 5
# test custom data cat
instance_data_1.polygons = TmpObjectWithoutCat(
np.arange(25).reshape((5, -1)).tolist())
instance_data_2 = instance_data_1.clone()
with pytest.raises(
ValueError,
match=('The type of `polygons` is '
f'`{type(instance_data_1.polygons)}` '
'which has no attribute of `cat`')):
cat_instance_data = InstanceData.cat(
[instance_data_1, instance_data_2])
def test_len(self): def test_len(self):
instance_data = self.setup_data() instance_data = self.setup_data()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment