Skip to content
Snippets Groups Projects
Unverified Commit 591024e5 authored by Xiangxu-0103's avatar Xiangxu-0103 Committed by GitHub
Browse files

[Docs] Update docstring of structures (#840)

* Update docstring of `structures`

* update docs

* add `import torch` to `examples`
parent e1f61252
No related branches found
No related tags found
No related merge requests found
...@@ -72,6 +72,7 @@ class BaseDataElement: ...@@ -72,6 +72,7 @@ class BaseDataElement:
model predictions. Defaults to None. model predictions. Defaults to None.
Examples: Examples:
>>> import torch
>>> from mmengine.structures import BaseDataElement >>> from mmengine.structures import BaseDataElement
>>> gt_instances = BaseDataElement() >>> gt_instances = BaseDataElement()
>>> bboxes = torch.rand((5, 4)) >>> bboxes = torch.rand((5, 4))
......
...@@ -22,7 +22,7 @@ class InstanceData(BaseDataElement): ...@@ -22,7 +22,7 @@ class InstanceData(BaseDataElement):
should have the same length. This design refer to should have the same length. This design refer to
https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py # noqa E501 https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py # noqa E501
InstanceData also support extra functions: ``index``, ``slice`` and ``cat`` for data field. The type of value InstanceData also support extra functions: ``index``, ``slice`` and ``cat`` for data field. The type of value
in data field can be base data structure such as `torch.tensor`, `numpy.ndarray`, `list`, `str`, `tuple`, in data field can be base data structure such as `torch.Tensor`, `numpy.ndarray`, `list`, `str`, `tuple`,
and can be customized data structure that has ``__len__``, ``__getitem__`` and ``cat`` attributes. and can be customized data structure that has ``__len__``, ``__getitem__`` and ``cat`` attributes.
Examples: Examples:
...@@ -34,7 +34,7 @@ class InstanceData(BaseDataElement): ...@@ -34,7 +34,7 @@ class InstanceData(BaseDataElement):
... def __len__(self): ... def __len__(self):
... return len(self.tmp) ... return len(self.tmp)
... def __getitem__(self, item): ... def __getitem__(self, item):
... if type(item) == int: ... if isinstance(item, int):
... if item >= len(self) or item < -len(self): # type:ignore ... if item >= len(self) or item < -len(self): # type:ignore
... raise IndexError(f'Index {item} out of range!') ... raise IndexError(f'Index {item} out of range!')
... else: ... else:
...@@ -54,6 +54,7 @@ class InstanceData(BaseDataElement): ...@@ -54,6 +54,7 @@ class InstanceData(BaseDataElement):
... return str(self.tmp) ... return str(self.tmp)
>>> from mmengine.structures import InstanceData >>> from mmengine.structures import InstanceData
>>> import numpy as np >>> import numpy as np
>>> import torch
>>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3)) >>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
>>> instance_data = InstanceData(metainfo=img_meta) >>> instance_data = InstanceData(metainfo=img_meta)
>>> 'img_shape' in instance_data >>> 'img_shape' in instance_data
...@@ -67,41 +68,39 @@ class InstanceData(BaseDataElement): ...@@ -67,41 +68,39 @@ class InstanceData(BaseDataElement):
>>> print(instance_data) >>> print(instance_data)
<InstanceData( <InstanceData(
META INFORMATION META INFORMATION
pad_shape: (800, 1196, 3) img_shape: (800, 1196, 3)
img_shape: (800, 1216, 3) pad_shape: (800, 1216, 3)
DATA FIELDS DATA FIELDS
det_labels: tensor([2, 3]) det_labels: tensor([2, 3])
det_scores: tensor([0.8, 0.7000]) det_scores: tensor([0.8000, 0.7000])
bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188], bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188],
[0.8101, 0.3105, 0.5123, 0.6263]]) [0.8101, 0.3105, 0.5123, 0.6263]])
polygons: [[1, 2, 3, 4], [5, 6, 7, 8]] polygons: [[1, 2, 3, 4], [5, 6, 7, 8]]
) at 0x7fb492de6280> ) at 0x7fb492de6280>
>>> sorted_results = instance_data[instance_data.det_scores.sort().indices] >>> sorted_results = instance_data[instance_data.det_scores.sort().indices]
>>> sorted_results.det_scores >>> sorted_results.det_scores
tensor([0.7000, 0.8000]) tensor([0.7000, 0.8000])
>>> print(instance_data[instance_data.det_scores > 0.75]) >>> print(instance_data[instance_data.det_scores > 0.75])
<InstanceData( <InstanceData(
META INFORMATION META INFORMATION
pad_shape: (800, 1216, 3)
img_shape: (800, 1196, 3) img_shape: (800, 1196, 3)
pad_shape: (800, 1216, 3)
DATA FIELDS DATA FIELDS
det_labels: tensor([2]) det_labels: tensor([2])
masks: [[11, 21, 31, 41]]
det_scores: tensor([0.8000]) det_scores: tensor([0.8000])
bboxes: tensor([[0.9308, 0.4000, 0.6077, 0.5554]]) bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188]])
polygons: [[1, 2, 3, 4]] polygons: [[1, 2, 3, 4]]
) at 0x7f64ecf0ec40> ) at 0x7f64ecf0ec40>
>>> print(instance_data[instance_data.det_scores > 1]) >>> print(instance_data[instance_data.det_scores > 1])
<InstanceData( <InstanceData(
META INFORMATION META INFORMATION
pad_shape: (800, 1216, 3)
img_shape: (800, 1196, 3) img_shape: (800, 1196, 3)
pad_shape: (800, 1216, 3)
DATA FIELDS DATA FIELDS
det_labels: tensor([], dtype=torch.int64) det_labels: tensor([], dtype=torch.int64)
masks: []
det_scores: tensor([]) det_scores: tensor([])
bboxes: tensor([], size=(0, 4)) bboxes: tensor([], size=(0, 4))
polygons: [[]] polygons: []
) at 0x7f660a6a7f70> ) at 0x7f660a6a7f70>
>>> print(instance_data.cat([instance_data, instance_data])) >>> print(instance_data.cat([instance_data, instance_data]))
<InstanceData( <InstanceData(
...@@ -110,44 +109,39 @@ class InstanceData(BaseDataElement): ...@@ -110,44 +109,39 @@ class InstanceData(BaseDataElement):
pad_shape: (800, 1216, 3) pad_shape: (800, 1216, 3)
DATA FIELDS DATA FIELDS
det_labels: tensor([2, 3, 2, 3]) det_labels: tensor([2, 3, 2, 3])
bboxes: tensor([[0.7404, 0.6332, 0.1684, 0.9961],
[0.2837, 0.8112, 0.5416, 0.2810],
[0.7404, 0.6332, 0.1684, 0.9961],
[0.2837, 0.8112, 0.5416, 0.2810]])
data:
polygons: [[1, 2, 3, 4], [5, 6, 7, 8],
[1, 2, 3, 4], [5, 6, 7, 8]]
det_scores: tensor([0.8000, 0.7000, 0.8000, 0.7000]) det_scores: tensor([0.8000, 0.7000, 0.8000, 0.7000])
masks: [[11, 21, 31, 41], [51, 61, 71, 81], bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188],
[11, 21, 31, 41], [51, 61, 71, 81]] [0.8101, 0.3105, 0.5123, 0.6263],
[0.4997, 0.7707, 0.0595, 0.4188],
[0.8101, 0.3105, 0.5123, 0.6263]])
polygons: [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [5, 6, 7, 8]]
) at 0x7f203542feb0> ) at 0x7f203542feb0>
""" """
def __setattr__(self, name: str, value: Sized): def __setattr__(self, name: str, value: Sized):
"""setattr is only used to set data. """setattr is only used to set data.
the value must have the attribute of `__len__` and have the same length The value must have the attribute of `__len__` and have the same length
of instancedata of `InstanceData`.
""" """
if name in ('_metainfo_fields', '_data_fields'): if name in ('_metainfo_fields', '_data_fields'):
if not hasattr(self, name): if not hasattr(self, name):
super().__setattr__(name, value) super().__setattr__(name, value)
else: else:
raise AttributeError( raise AttributeError(f'{name} has been used as a '
f'{name} has been used as a ' 'private attribute, which is immutable.')
f'private attribute, which is immutable. ')
else: else:
assert isinstance(value, assert isinstance(value,
Sized), 'value must contain `_len__` attribute' Sized), 'value must contain `__len__` attribute'
if len(self) > 0: if len(self) > 0:
assert len(value) == len(self), f'the length of ' \ assert len(value) == len(self), 'The length of ' \
f'values {len(value)} is ' \ f'values {len(value)} is ' \
f'not consistent with' \ 'not consistent with ' \
f' the length of this ' \ 'the length of this ' \
f':obj:`InstanceData` ' \ ':obj:`InstanceData` ' \
f'{len(self)} ' f'{len(self)}'
super().__setattr__(name, value) super().__setattr__(name, value)
__setitem__ = __setattr__ __setitem__ = __setattr__
...@@ -155,12 +149,12 @@ class InstanceData(BaseDataElement): ...@@ -155,12 +149,12 @@ class InstanceData(BaseDataElement):
def __getitem__(self, item: IndexType) -> 'InstanceData': def __getitem__(self, item: IndexType) -> 'InstanceData':
""" """
Args: Args:
item (str, obj:`slice`, item (str, int, list, :obj:`slice`, :obj:`numpy.ndarray`,
obj`torch.LongTensor`, obj:`torch.BoolTensor`): :obj:`torch.LongTensor`, :obj:`torch.BoolTensor`):
get the corresponding values according to item. Get the corresponding values according to item.
Returns: Returns:
obj:`InstanceData`: Corresponding values. :obj:`InstanceData`: Corresponding values.
""" """
if isinstance(item, list): if isinstance(item, list):
item = np.array(item) item = np.array(item)
...@@ -178,7 +172,7 @@ class InstanceData(BaseDataElement): ...@@ -178,7 +172,7 @@ class InstanceData(BaseDataElement):
if isinstance(item, str): if isinstance(item, str):
return getattr(self, item) return getattr(self, item)
if type(item) == int: if isinstance(item, int):
if item >= len(self) or item < -len(self): # type:ignore if item >= len(self) or item < -len(self): # type:ignore
raise IndexError(f'Index {item} out of range!') raise IndexError(f'Index {item} out of range!')
else: else:
...@@ -190,14 +184,14 @@ class InstanceData(BaseDataElement): ...@@ -190,14 +184,14 @@ class InstanceData(BaseDataElement):
assert item.dim() == 1, 'Only support to get the' \ assert item.dim() == 1, 'Only support to get the' \
' values along the first dimension.' ' values along the first dimension.'
if isinstance(item, (torch.BoolTensor, torch.cuda.BoolTensor)): if isinstance(item, (torch.BoolTensor, torch.cuda.BoolTensor)):
assert len(item) == len(self), f'The shape of the' \ assert len(item) == len(self), 'The shape of the ' \
f' input(BoolTensor)) ' \ 'input(BoolTensor) ' \
f'{len(item)} ' \ f'{len(item)} ' \
f' does not match the shape ' \ 'does not match the shape ' \
f'of the indexed tensor ' \ 'of the indexed tensor ' \
f'in results_filed ' \ 'in results_field ' \
f'{len(self)} at ' \ f'{len(self)} at ' \
f'first dimension. ' 'first dimension.'
for k, v in self.items(): for k, v in self.items():
if isinstance(v, torch.Tensor): if isinstance(v, torch.Tensor):
...@@ -207,7 +201,7 @@ class InstanceData(BaseDataElement): ...@@ -207,7 +201,7 @@ class InstanceData(BaseDataElement):
elif isinstance( elif isinstance(
v, (str, list, tuple)) or (hasattr(v, '__getitem__') v, (str, list, tuple)) or (hasattr(v, '__getitem__')
and hasattr(v, 'cat')): and hasattr(v, 'cat')):
# convert to indexes from boolTensor # convert to indexes from BoolTensor
if isinstance(item, if isinstance(item,
(torch.BoolTensor, torch.cuda.BoolTensor)): (torch.BoolTensor, torch.cuda.BoolTensor)):
indexes = torch.nonzero(item).view( indexes = torch.nonzero(item).view(
...@@ -232,7 +226,7 @@ class InstanceData(BaseDataElement): ...@@ -232,7 +226,7 @@ class InstanceData(BaseDataElement):
raise ValueError( raise ValueError(
f'The type of `{k}` is `{type(v)}`, which has no ' f'The type of `{k}` is `{type(v)}`, which has no '
'attribute of `cat`, so it does not ' 'attribute of `cat`, so it does not '
f'support slice with `bool`') 'support slice with `bool`')
else: else:
# item is a slice # item is a slice
...@@ -252,7 +246,7 @@ class InstanceData(BaseDataElement): ...@@ -252,7 +246,7 @@ class InstanceData(BaseDataElement):
of :obj:`InstanceData`. of :obj:`InstanceData`.
Returns: Returns:
obj:`InstanceData` :obj:`InstanceData`
""" """
assert all( assert all(
isinstance(results, InstanceData) for results in instances_list) isinstance(results, InstanceData) for results in instances_list)
...@@ -272,7 +266,7 @@ class InstanceData(BaseDataElement): ...@@ -272,7 +266,7 @@ class InstanceData(BaseDataElement):
'cause the cat operation ' \ 'cause the cat operation ' \
'to fail. Please make sure all ' \ 'to fail. Please make sure all ' \
'elements in `instances_list` ' \ 'elements in `instances_list` ' \
'have the exact same key ' 'have the exact same key.'
new_data = instances_list[0].__class__( new_data = instances_list[0].__class__(
metainfo=instances_list[0].metainfo) metainfo=instances_list[0].metainfo)
...@@ -297,7 +291,7 @@ class InstanceData(BaseDataElement): ...@@ -297,7 +291,7 @@ class InstanceData(BaseDataElement):
return new_data # type:ignore return new_data # type:ignore
def __len__(self) -> int: def __len__(self) -> int:
"""int: the length of InstanceData""" """int: The length of InstanceData."""
if len(self._data_fields) > 0: if len(self._data_fields) > 0:
return len(self.values()[0]) return len(self.values()[0])
else: else:
......
...@@ -16,7 +16,7 @@ class LabelData(BaseDataElement): ...@@ -16,7 +16,7 @@ class LabelData(BaseDataElement):
onehot (torch.Tensor, optional): The one-hot input. The format onehot (torch.Tensor, optional): The one-hot input. The format
of input must be one-hot. of input must be one-hot.
Return: Returns:
torch.Tensor: The converted results. torch.Tensor: The converted results.
""" """
assert isinstance(onehot, torch.Tensor) assert isinstance(onehot, torch.Tensor)
...@@ -36,7 +36,7 @@ class LabelData(BaseDataElement): ...@@ -36,7 +36,7 @@ class LabelData(BaseDataElement):
of item must be label-format. of item must be label-format.
num_classes (int): The number of classes. num_classes (int): The number of classes.
Return: Returns:
torch.Tensor: The converted results. torch.Tensor: The converted results.
""" """
assert isinstance(label, torch.Tensor) assert isinstance(label, torch.Tensor)
......
...@@ -26,12 +26,12 @@ class PixelData(BaseDataElement): ...@@ -26,12 +26,12 @@ class PixelData(BaseDataElement):
>>> pixel_data = PixelData(metainfo=metainfo, >>> pixel_data = PixelData(metainfo=metainfo,
... image=image, ... image=image,
... featmap=featmap) ... featmap=featmap)
>>> print(pixel_data) >>> print(pixel_data.shape)
>>> (20, 40) (20, 40)
>>> # slice >>> # slice
>>> slice_data = pixel_data[10:20, 20:40] >>> slice_data = pixel_data[10:20, 20:40]
>>> assert slice_data.shape == (10, 10) >>> assert slice_data.shape == (10, 20)
>>> slice_data = pixel_data[10, 20] >>> slice_data = pixel_data[10, 20]
>>> assert slice_data.shape == (1, 1) >>> assert slice_data.shape == (1, 1)
...@@ -47,41 +47,40 @@ class PixelData(BaseDataElement): ...@@ -47,41 +47,40 @@ class PixelData(BaseDataElement):
"""Set attributes of ``PixelData``. """Set attributes of ``PixelData``.
If the dimension of value is 2 and its shape meet the demand, it If the dimension of value is 2 and its shape meet the demand, it
will automatically expend its channel-dimension. will automatically expand its channel-dimension.
Args: Args:
name (str): The key to access the value, stored in `PixelData`. name (str): The key to access the value, stored in `PixelData`.
value (Union[torch.Tensor, np.ndarray]): The value to store in. value (Union[torch.Tensor, np.ndarray]): The value to store in.
The type of value must be `torch.Tensor` or `np.ndarray`, The type of value must be `torch.Tensor` or `np.ndarray`,
and its shape must meet the requirements of `PixelData`. and its shape must meet the requirements of `PixelData`.
""" """
if name in ('_metainfo_fields', '_data_fields'): if name in ('_metainfo_fields', '_data_fields'):
if not hasattr(self, name): if not hasattr(self, name):
super().__setattr__(name, value) super().__setattr__(name, value)
else: else:
raise AttributeError( raise AttributeError(f'{name} has been used as a '
f'{name} has been used as a ' 'private attribute, which is immutable.')
f'private attribute, which is immutable. ')
else: else:
assert isinstance(value, (torch.Tensor, np.ndarray)), \ assert isinstance(value, (torch.Tensor, np.ndarray)), \
f'Can set {type(value)}, only support' \ f'Can not set {type(value)}, only support' \
f' {(torch.Tensor, np.ndarray)}' f' {(torch.Tensor, np.ndarray)}'
if self.shape: if self.shape:
assert tuple(value.shape[-2:]) == self.shape, ( assert tuple(value.shape[-2:]) == self.shape, (
f'the height and width of ' 'The height and width of '
f'values {tuple(value.shape[-2:])} is ' f'values {tuple(value.shape[-2:])} is '
f'not consistent with' 'not consistent with '
f' the length of this ' 'the shape of this '
f':obj:`PixelData` ' ':obj:`PixelData` '
f'{self.shape} ') f'{self.shape}')
assert value.ndim in [ assert value.ndim in [
2, 3 2, 3
], f'The dim of value must be 2 or 3, but got {value.ndim}' ], f'The dim of value must be 2 or 3, but got {value.ndim}'
if value.ndim == 2: if value.ndim == 2:
value = value[None] value = value[None]
warnings.warn(f'The shape of value will convert from ' warnings.warn('The shape of value will convert from '
f'{value.shape[-2:]} to {value.shape}') f'{value.shape[-2:]} to {value.shape}')
super().__setattr__(name, value) super().__setattr__(name, value)
...@@ -89,17 +88,17 @@ class PixelData(BaseDataElement): ...@@ -89,17 +88,17 @@ class PixelData(BaseDataElement):
def __getitem__(self, item: Sequence[Union[int, slice]]) -> 'PixelData': def __getitem__(self, item: Sequence[Union[int, slice]]) -> 'PixelData':
""" """
Args: Args:
item (Sequence[Union[int, slice]]): get the corresponding values item (Sequence[Union[int, slice]]): Get the corresponding values
according to item. according to item.
Returns: Returns:
obj:`PixelData`: Corresponding values. :obj:`PixelData`: Corresponding values.
""" """
new_data = self.__class__(metainfo=self.metainfo) new_data = self.__class__(metainfo=self.metainfo)
if isinstance(item, tuple): if isinstance(item, tuple):
assert len(item) == 2, 'Only support slice height and width' assert len(item) == 2, 'Only support to slice height and width'
tmp_item: List[slice] = list() tmp_item: List[slice] = list()
for index, single_item in enumerate(item[::-1]): for index, single_item in enumerate(item[::-1]):
if isinstance(single_item, int): if isinstance(single_item, int):
......
...@@ -23,7 +23,7 @@ class TmpObject: ...@@ -23,7 +23,7 @@ class TmpObject:
return len(self.tmp) return len(self.tmp)
def __getitem__(self, item): def __getitem__(self, item):
if type(item) == int: if isinstance(item, int):
if item >= len(self) or item < -len(self): # type:ignore if item >= len(self) or item < -len(self): # type:ignore
raise IndexError(f'Index {item} out of range!') raise IndexError(f'Index {item} out of range!')
else: else:
...@@ -58,13 +58,13 @@ class TmpObjectWithoutCat: ...@@ -58,13 +58,13 @@ class TmpObjectWithoutCat:
return len(self.tmp) return len(self.tmp)
def __getitem__(self, item): def __getitem__(self, item):
if type(item) == int: if isinstance(item, int):
if item >= len(self) or item < -len(self): # type:ignore if item >= len(self) or item < -len(self): # type:ignore
raise IndexError(f'Index {item} out of range!') raise IndexError(f'Index {item} out of range!')
else: else:
# keep the dimension # keep the dimension
item = slice(item, None, len(self)) item = slice(item, None, len(self))
return TmpObject(self.tmp[item]) return TmpObjectWithoutCat(self.tmp[item])
def __repr__(self): def __repr__(self):
return str(self.tmp) return str(self.tmp)
...@@ -131,18 +131,18 @@ class TestInstanceData(TestCase): ...@@ -131,18 +131,18 @@ class TestInstanceData(TestCase):
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
instance_data[item] instance_data[item]
# when input is a bool tensor, The shape of # when input is a bool tensor, the shape of
# the input at index 0 should equal to # the input at index 0 should equal to
# the value length in instance_data_field # the value length in instance_data_field
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
instance_data[item.bool()] instance_data[item.bool()]
# test Longtensor # test LongTensor
long_tensor = torch.randint(5, (2, )) long_tensor = torch.randint(5, (2, ))
long_index_instance_data = instance_data[long_tensor] long_index_instance_data = instance_data[long_tensor]
assert len(long_index_instance_data) == len(long_tensor) assert len(long_index_instance_data) == len(long_tensor)
# test bool tensor # test BoolTensor
bool_tensor = torch.rand(5) > 0.5 bool_tensor = torch.rand(5) > 0.5
bool_index_instance_data = instance_data[bool_tensor] bool_index_instance_data = instance_data[bool_tensor]
assert len(bool_index_instance_data) == bool_tensor.sum() assert len(bool_index_instance_data) == bool_tensor.sum()
...@@ -155,7 +155,7 @@ class TestInstanceData(TestCase): ...@@ -155,7 +155,7 @@ class TestInstanceData(TestCase):
list_index_instance_data = instance_data[list_index] list_index_instance_data = instance_data[list_index]
assert len(list_index_instance_data) == len(list_index) assert len(list_index_instance_data) == len(list_index)
# text list bool # test list bool
list_bool = [True, False, True, False, False] list_bool = [True, False, True, False, False]
list_bool_instance_data = instance_data[list_bool] list_bool_instance_data = instance_data[list_bool]
assert len(list_bool_instance_data) == 2 assert len(list_bool_instance_data) == 2
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment