Skip to content
Snippets Groups Projects
Unverified Commit 92674eda authored by Xiangxu-0103's avatar Xiangxu-0103 Committed by GitHub
Browse files

[Docs] Update docstring of BaseDataElement (#836)


* update doc

* update doc

* Update mmengine/structures/base_data_element.py

* Update mmengine/structures/base_data_element.py

* Apply suggestions from code review

Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent 2eede1b9
No related branches found
No related tags found
No related merge requests found
...@@ -46,27 +46,27 @@ class BaseDataElement: ...@@ -46,27 +46,27 @@ class BaseDataElement:
The attributes in ``BaseDataElement`` are divided into two parts, The attributes in ``BaseDataElement`` are divided into two parts,
the ``metainfo`` and the ``data`` respectively. the ``metainfo`` and the ``data`` respectively.
- ``metainfo``: Usually contains the - ``metainfo``: Usually contains the
information about the image such as filename, information about the image such as filename,
image_shape, pad_shape, etc. The attributes can be accessed or image_shape, pad_shape, etc. The attributes can be accessed or
modified by dict-like or object-like operations, such as modified by dict-like or object-like operations, such as
``.``(for data access and modification) , ``in``, ``del``, ``.`` (for data access and modification), ``in``, ``del``,
``pop(str)``, ``get(str)``, ``metainfo_keys()``, ``pop(str)``, ``get(str)``, ``metainfo_keys()``,
``metainfo_values()``, ``metainfo_items()``, ``set_metainfo()``(for ``metainfo_values()``, ``metainfo_items()``, ``set_metainfo()`` (for
set or change key-value pairs in metainfo). set or change key-value pairs in metainfo).
- ``data``: Annotations or model predictions are - ``data``: Annotations or model predictions are
stored. The attributes can be accessed or modified by stored. The attributes can be accessed or modified by
dict-like or object-like operations, such as dict-like or object-like operations, such as
``.`` , ``in``, ``del``, ``pop(str)`` ``get(str)``, ``keys()``, ``.``, ``in``, ``del``, ``pop(str)``, ``get(str)``, ``keys()``,
``values()``, ``items()``. Users can also apply tensor-like ``values()``, ``items()``. Users can also apply tensor-like
methods to all obj:``torch.Tensor`` in the ``data_fileds``, methods to all :obj:`torch.Tensor` in the ``data_fields``,
such as ``.cuda()``, ``.cpu()``, ``.numpy()``, , ``.to()`` such as ``.cuda()``, ``.cpu()``, ``.numpy()``, ``.to()``,
``to_tensor()``, ``.detach()``. ``to_tensor()``, ``.detach()``.
Args: Args:
metainfo (dict, optional): A dict contains the meta information metainfo (dict, optional): A dict contains the meta information
of single image. such as ``dict(img_shape=(512, 512, 3), of single image, such as ``dict(img_shape=(512, 512, 3),
scale_factor=(1, 1, 1, 1))``. Defaults to None. scale_factor=(1, 1, 1, 1))``. Defaults to None.
kwargs (dict, optional): A dict contains annotations of single image or kwargs (dict, optional): A dict contains annotations of single image or
model predictions. Defaults to None. model predictions. Defaults to None.
...@@ -82,46 +82,52 @@ class BaseDataElement: ...@@ -82,46 +82,52 @@ class BaseDataElement:
... metainfo=dict(img_id=img_id, img_shape=img_shape), ... metainfo=dict(img_id=img_id, img_shape=img_shape),
... bboxes=bboxes, scores=scores) ... bboxes=bboxes, scores=scores)
>>> gt_instances = BaseDataElement( >>> gt_instances = BaseDataElement(
... metainfo=dict(img_id=img_id, ... metainfo=dict(img_id=img_id, img_shape=(640, 640)))
... img_shape=(H, W)))
>>> # new >>> # new
>>> gt_instances1 = gt_instance.new( >>> gt_instances1 = gt_instances.new(
... metainfo=dict(img_id=1, img_shape=(640, 640)), ... metainfo=dict(img_id=1, img_shape=(640, 640)),
... bboxes=torch.rand((5, 4)), ... bboxes=torch.rand((5, 4)),
... scores=torch.rand((5,))) ... scores=torch.rand((5,)))
>>> gt_instances2 = gt_instances1.new() >>> gt_instances2 = gt_instances1.new()
>>> # add and process property >>> # add and process property
>>> gt_instances = BaseDataElement() >>> gt_instances = BaseDataElement()
>>> gt_instances.set_metainfo(dict(img_id=9, img_shape=(100, 100)) >>> gt_instances.set_metainfo(dict(img_id=9, img_shape=(100, 100)))
>>> assert 'img_shape' in gt_instances.metainfo_keys() >>> assert 'img_shape' in gt_instances.metainfo_keys()
>>> assert 'img_shape' in gt_instances >>> assert 'img_shape' in gt_instances
>>> assert 'img_shape' not in gt_instances.keys() >>> assert 'img_shape' not in gt_instances.keys()
>>> assert 'img_shape' in gt_instances.all_keys() >>> assert 'img_shape' in gt_instances.all_keys()
>>> print(gt_instances.img_shape) >>> print(gt_instances.img_shape)
(100, 100)
>>> gt_instances.scores = torch.rand((5,)) >>> gt_instances.scores = torch.rand((5,))
>>> assert 'scores' in gt_instances.keys() >>> assert 'scores' in gt_instances.keys()
>>> assert 'scores' in gt_instances >>> assert 'scores' in gt_instances
>>> assert 'scores' in gt_instances.all_keys() >>> assert 'scores' in gt_instances.all_keys()
>>> assert 'scores' not in gt_instances.metainfo_keys() >>> assert 'scores' not in gt_instances.metainfo_keys()
>>> print(gt_instances.scores) >>> print(gt_instances.scores)
tensor([0.5230, 0.7885, 0.2426, 0.3911, 0.4876])
>>> gt_instances.bboxes = torch.rand((5, 4)) >>> gt_instances.bboxes = torch.rand((5, 4))
>>> assert 'bboxes' in gt_instances.keys() >>> assert 'bboxes' in gt_instances.keys()
>>> assert 'bboxes' in gt_instances >>> assert 'bboxes' in gt_instances
>>> assert 'bboxes' in gt_instances.all_keys() >>> assert 'bboxes' in gt_instances.all_keys()
>>> assert 'bboxes' not in gt_instances.metainfo_keys() >>> assert 'bboxes' not in gt_instances.metainfo_keys()
>>> print(gt_instances.bboxes) >>> print(gt_instances.bboxes)
tensor([[0.0900, 0.0424, 0.1755, 0.4469],
[0.8648, 0.0592, 0.3484, 0.0913],
[0.5808, 0.1909, 0.6165, 0.7088],
[0.5490, 0.4209, 0.9416, 0.2374],
[0.3652, 0.1218, 0.8805, 0.7523]])
>>> # delete and change property >>> # delete and change property
>>> gt_instances = BaseDataElement( >>> gt_instances = BaseDataElement(
... metainfo=dict(img_id=0, img_shape=(640, 640)), ... metainfo=dict(img_id=0, img_shape=(640, 640)),
... bboxes=torch.rand((6, 4)), scores=torch.rand((6,))) ... bboxes=torch.rand((6, 4)), scores=torch.rand((6,)))
>>> gt_instances.img_shape = (1280, 1280) >>> gt_instances.set_metainfo(dict(img_shape=(1280, 1280)))
>>> gt_instances.img_shape # (1280, 1280) >>> gt_instances.img_shape # (1280, 1280)
>>> gt_instances.bboxes = gt_instances.bboxes * 2 >>> gt_instances.bboxes = gt_instances.bboxes * 2
>>> gt_instances.get('img_shape', None) # (640, 640) >>> gt_instances.get('img_shape', None) # (1280, 1280)
>>> gt_instances.get('bboxes', None) # 6x4 tensor >>> gt_instances.get('bboxes', None) # 6x4 tensor
>>> del gt_instances.img_shape >>> del gt_instances.img_shape
>>> del gt_instances.bboxes >>> del gt_instances.bboxes
>>> assert 'img_shape' not in gt_instances >>> assert 'img_shape' not in gt_instances
...@@ -131,19 +137,19 @@ class BaseDataElement: ...@@ -131,19 +137,19 @@ class BaseDataElement:
>>> # Tensor-like >>> # Tensor-like
>>> cuda_instances = gt_instances.cuda() >>> cuda_instances = gt_instances.cuda()
>>> cuda_instances = gt_instancess.to('cuda:0') >>> cuda_instances = gt_instances.to('cuda:0')
>>> cpu_instances = cuda_instances.cpu() >>> cpu_instances = cuda_instances.cpu()
>>> cpu_instances = cuda_instances.to('cpu') >>> cpu_instances = cuda_instances.to('cpu')
>>> fp16_instances = cuda_instances.to( >>> fp16_instances = cuda_instances.to(
... device=None, dtype=torch.float16, non_blocking=False, copy=False, ... device=None, dtype=torch.float16, non_blocking=False,
... memory_format=torch.preserve_format) ... copy=False, memory_format=torch.preserve_format)
>>> cpu_instances = cuda_instances.detach() >>> cpu_instances = cuda_instances.detach()
>>> np_instances = cpu_instances.numpy() >>> np_instances = cpu_instances.numpy()
>>> # print >>> # print
>>> metainfo = dict(img_shape=(800, 1196, 3)) >>> metainfo = dict(img_shape=(800, 1196, 3))
>>> gt_instances = BaseDataElement( >>> gt_instances = BaseDataElement(
>>> metainfo=metainfo, det_labels=torch.LongTensor([0, 1, 2, 3])) ... metainfo=metainfo, det_labels=torch.LongTensor([0, 1, 2, 3]))
>>> sample = BaseDataElement(metainfo=metainfo, >>> sample = BaseDataElement(metainfo=metainfo,
... gt_instances=gt_instances) ... gt_instances=gt_instances)
>>> print(sample) >>> print(sample)
...@@ -185,7 +191,7 @@ class BaseDataElement: ...@@ -185,7 +191,7 @@ class BaseDataElement:
... return self._pred_instances ... return self._pred_instances
... @pred_instances.setter ... @pred_instances.setter
... def pred_instances(self, value): ... def pred_instances(self, value):
... self.set_field(value,'_pred_instances', ... self.set_field(value, '_pred_instances',
... dtype=BaseDataElement) ... dtype=BaseDataElement)
... @pred_instances.deleter ... @pred_instances.deleter
... def pred_instances(self): ... def pred_instances(self):
...@@ -235,7 +241,7 @@ class BaseDataElement: ...@@ -235,7 +241,7 @@ class BaseDataElement:
model predictions. model predictions.
""" """
assert isinstance(data, assert isinstance(data,
dict), f'meta should be a `dict` but got {data}' dict), f'data should be a `dict` but got {data}'
for k, v in data.items(): for k, v in data.items():
# Use `setattr()` rather than `self.set_field` to allow `set_data` # Use `setattr()` rather than `self.set_field` to allow `set_data`
# to set property method. # to set property method.
...@@ -247,7 +253,7 @@ class BaseDataElement: ...@@ -247,7 +253,7 @@ class BaseDataElement:
Args: Args:
instance (BaseDataElement): Another BaseDataElement object for instance (BaseDataElement): Another BaseDataElement object for
update the current object. update the current object.
""" """
assert isinstance( assert isinstance(
instance, BaseDataElement instance, BaseDataElement
...@@ -272,7 +278,7 @@ class BaseDataElement: ...@@ -272,7 +278,7 @@ class BaseDataElement:
model predictions. model predictions.
Returns: Returns:
BaseDataElement: a new data element with same type. BaseDataElement: A new data element with same type.
""" """
new_data = self.__class__() new_data = self.__class__()
...@@ -290,7 +296,7 @@ class BaseDataElement: ...@@ -290,7 +296,7 @@ class BaseDataElement:
"""Deep copy the current data element. """Deep copy the current data element.
Returns: Returns:
BaseDataElement: the copy of current data element. BaseDataElement: The copy of current data element.
""" """
clone_data = self.__class__() clone_data = self.__class__()
clone_data.set_metainfo(dict(self.metainfo_items())) clone_data.set_metainfo(dict(self.metainfo_items()))
...@@ -342,7 +348,7 @@ class BaseDataElement: ...@@ -342,7 +348,7 @@ class BaseDataElement:
def all_items(self) -> Iterator[Tuple[str, Any]]: def all_items(self) -> Iterator[Tuple[str, Any]]:
""" """
Returns: Returns:
iterator: an iterator object whose element is (key, value) tuple iterator: An iterator object whose element is (key, value) tuple
pairs for ``metainfo`` and ``data``. pairs for ``metainfo`` and ``data``.
""" """
for k in self.all_keys(): for k in self.all_keys():
...@@ -351,7 +357,7 @@ class BaseDataElement: ...@@ -351,7 +357,7 @@ class BaseDataElement:
def items(self) -> Iterator[Tuple[str, Any]]: def items(self) -> Iterator[Tuple[str, Any]]:
""" """
Returns: Returns:
iterator: an iterator object whose element is (key, value) tuple iterator: An iterator object whose element is (key, value) tuple
pairs for ``data``. pairs for ``data``.
""" """
for k in self.keys(): for k in self.keys():
...@@ -360,7 +366,7 @@ class BaseDataElement: ...@@ -360,7 +366,7 @@ class BaseDataElement:
def metainfo_items(self) -> Iterator[Tuple[str, Any]]: def metainfo_items(self) -> Iterator[Tuple[str, Any]]:
""" """
Returns: Returns:
iterator: an iterator object whose element is (key, value) tuple iterator: An iterator object whose element is (key, value) tuple
pairs for ``metainfo``. pairs for ``metainfo``.
""" """
for k in self.metainfo_keys(): for k in self.metainfo_keys():
...@@ -378,20 +384,20 @@ class BaseDataElement: ...@@ -378,20 +384,20 @@ class BaseDataElement:
super().__setattr__(name, value) super().__setattr__(name, value)
else: else:
raise AttributeError(f'{name} has been used as a ' raise AttributeError(f'{name} has been used as a '
'private attribute, which is immutable. ') 'private attribute, which is immutable.')
else: else:
self.set_field( self.set_field(
name=name, value=value, field_type='data', dtype=None) name=name, value=value, field_type='data', dtype=None)
def __delattr__(self, item: str): def __delattr__(self, item: str):
"""delete the item in dataelement. """Delete the item in dataelement.
Args: Args:
item (str): The key to delete. item (str): The key to delete.
""" """
if item in ('_metainfo_fields', '_data_fields'): if item in ('_metainfo_fields', '_data_fields'):
raise AttributeError(f'{item} has been used as a ' raise AttributeError(f'{item} has been used as a '
'private attribute, which is immutable. ') 'private attribute, which is immutable.')
super().__delattr__(item) super().__delattr__(item)
if item in self._metainfo_fields: if item in self._metainfo_fields:
self._metainfo_fields.remove(item) self._metainfo_fields.remove(item)
...@@ -402,13 +408,13 @@ class BaseDataElement: ...@@ -402,13 +408,13 @@ class BaseDataElement:
__delitem__ = __delattr__ __delitem__ = __delattr__
def get(self, key, default=None) -> Any: def get(self, key, default=None) -> Any:
"""get property in data and metainfo as the same as python.""" """Get property in data and metainfo as the same as python."""
# Use `getattr()` rather than `self.__dict__.get()` to allow getting # Use `getattr()` rather than `self.__dict__.get()` to allow getting
# properties. # properties.
return getattr(self, key, default) return getattr(self, key, default)
def pop(self, *args) -> Any: def pop(self, *args) -> Any:
"""pop property in data and metainfo as the same as python.""" """Pop property in data and metainfo as the same as python."""
assert len(args) < 3, '``pop`` get more than 2 arguments' assert len(args) < 3, '``pop`` get more than 2 arguments'
name = args[0] name = args[0]
if name in self._metainfo_fields: if name in self._metainfo_fields:
...@@ -459,7 +465,7 @@ class BaseDataElement: ...@@ -459,7 +465,7 @@ class BaseDataElement:
raise AttributeError( raise AttributeError(
f'Cannot set {name} to be a field of data ' f'Cannot set {name} to be a field of data '
f'because {name} is already a metainfo field') f'because {name} is already a metainfo field')
# The name only added to `data_fields`` when it is not the # The name only added to `data_fields` when it is not the
# attribute related to property(methods decorated by @property). # attribute related to property(methods decorated by @property).
if not isinstance( if not isinstance(
getattr(type(self), getattr(type(self),
...@@ -513,7 +519,7 @@ class BaseDataElement: ...@@ -513,7 +519,7 @@ class BaseDataElement:
# Tensor-like methods # Tensor-like methods
def numpy(self) -> 'BaseDataElement': def numpy(self) -> 'BaseDataElement':
"""Convert all tensor to np.narray in data.""" """Convert all tensors to np.ndarray in data."""
new_data = self.new() new_data = self.new()
for k, v in self.items(): for k, v in self.items():
if isinstance(v, (torch.Tensor, BaseDataElement)): if isinstance(v, (torch.Tensor, BaseDataElement)):
...@@ -523,7 +529,7 @@ class BaseDataElement: ...@@ -523,7 +529,7 @@ class BaseDataElement:
return new_data return new_data
def to_tensor(self) -> 'BaseDataElement': def to_tensor(self) -> 'BaseDataElement':
"""Convert all np.narray to tensor in data.""" """Convert all np.ndarray to tensor in data."""
new_data = self.new() new_data = self.new()
for k, v in self.items(): for k, v in self.items():
data = {} data = {}
...@@ -544,7 +550,7 @@ class BaseDataElement: ...@@ -544,7 +550,7 @@ class BaseDataElement:
} }
def __repr__(self) -> str: def __repr__(self) -> str:
"""represent the object.""" """Represent the object."""
def _addindent(s_: str, num_spaces: int) -> str: def _addindent(s_: str, num_spaces: int) -> str:
"""This func is modified from `pytorch` https://github.com/pytorch/ """This func is modified from `pytorch` https://github.com/pytorch/
...@@ -569,13 +575,13 @@ class BaseDataElement: ...@@ -569,13 +575,13 @@ class BaseDataElement:
return s # type: ignore return s # type: ignore
def dump(obj: Any) -> str: def dump(obj: Any) -> str:
"""represent the object. """Represent the object.
Args: Args:
obj (Any): The obj to represent. obj (Any): The obj to represent.
Returns: Returns:
str: The represented str . str: The represented str.
""" """
_repr = '' _repr = ''
if isinstance(obj, dict): if isinstance(obj, dict):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment