From 9437ebea672af949cb7dc70ec88484089a014670 Mon Sep 17 00:00:00 2001 From: liukuikun <24622904+Harold-lkk@users.noreply.github.com> Date: Tue, 22 Feb 2022 21:45:32 +0800 Subject: [PATCH] [Feat]: Add abstract data structure (#29) * [WIP] abstract-data-structure * update docs * update * update BaseDataSample * fix comment and coverage 100% * update and add _set_field * update * split into base_data_element and base_data_sample * update * update * update * fix typo --- .../tutorials/abstract_data_interface.md | 28 +- mmengine/__init__.py | 1 + mmengine/data/__init__.py | 6 +- mmengine/data/base_data_element.py | 455 ++++++++++++++ mmengine/data/base_data_sample.py | 563 ++++++++++++++++++ tests/test_data/test_data_element.py | 211 ++++++- tests/test_data/test_data_sample.py | 300 ++++++++-- 7 files changed, 1452 insertions(+), 112 deletions(-) create mode 100644 mmengine/data/base_data_element.py create mode 100644 mmengine/data/base_data_sample.py diff --git a/docs/zh_cn/tutorials/abstract_data_interface.md b/docs/zh_cn/tutorials/abstract_data_interface.md index 6f2e2905..13b93385 100644 --- a/docs/zh_cn/tutorials/abstract_data_interface.md +++ b/docs/zh_cn/tutorials/abstract_data_interface.md @@ -244,8 +244,8 @@ MMEngine ä¸ºæ ·æœ¬æ•°æ®çš„å°è£…æ供了一个基类 `BaseDataSample`,OpenMM `BaseDataSample` 内部ä¾ç„¶åŒºåˆ† metainfo å’Œ data,并且支æŒåƒç±»ä¸€æ ·å¯¹å…¶å±žæ€§è¿›è¡Œè®¾ç½®å’Œè°ƒæ•´ï¼Œä¸ºäº†ä¿è¯ç”¨æˆ·ä½“验的一致性,`BaseDataSample` 的外部接å£ç”¨æ³•å’Œ `BaseDataElement` ä¿æŒä¸€è‡´ã€‚ åŒæ—¶ï¼Œç”±äºŽ `BaseDataSample` 作为基类一般ä¸ä¼šç›´æŽ¥ä½¿ç”¨ï¼Œä¸ºäº†æ–¹ä¾¿ä¸‹æ¸¸ç®—法库快速定义其å类,并对åç±»çš„å±žæ€§è¿›è¡Œè§„çº¦å’Œæ ¡éªŒã€‚ -`BaseDataSample` é¢å¤–æä¾›äº†ä¸€å¥—å†…éƒ¨æŽ¥å£ `_get_field`, `_del_field` å’Œ `_set_field` æ¥ä¾¿åˆ©å®ƒçš„å类快æ·åœ°å®šä¹‰å’Œè§„约 data å±žæ€§çš„å¢žåˆ æ”¹æŸ¥ã€‚ -`_set_field` ä¸ä¼šè¢«å½“作外部接å£ç›´æŽ¥ä½¿ç”¨ï¼Œè€Œæ˜¯è¢«ç”¨æ¥å®šä¹‰å±žæ€§ï¼ˆproperty) çš„ `setter` 并æä¾›åŸºæœ¬çš„ç±»åž‹æ ¡éªŒã€‚ +`BaseDataSample` é¢å¤–æä¾›äº†ä¸€å¥—å†…éƒ¨æŽ¥å£ `get_field`, `del_field` å’Œ `set_field` 
来便利它的子类快捷地定义和规约 data 属性的增删改查。 +`set_field` 不会被当作外部接口直接使用,而是被用来定义属性(property) 的 `setter` 并提供基本的类型校验。 一个简单粗略的实现和用例如下。 @@ -263,15 +263,15 @@ class BaseDataSample(ABC): # 其他功能实现 ... - def _get_field(self, name): + def get_field(self, name): return getattr(self, name) - def _set_field(self, val, name, dtype): + def set_field(self, val, name, dtype): assert isinstance(val, dtype) super().__setattr__(name, val) self._data_fields.add(name) - def _del_field(self, name): + def del_field(self, name): super().__delattr__(name) self._data_fields.remove(name) @@ -284,24 +284,24 @@ class DetDataSample(BaseDataSample): proposals = property( # 定义了 get 方法,通过 name '_proposals' 来访问实际维护的变量 - fget=partial(BaseDataSample._get_field, name='_proposals'), + fget=partial(BaseDataSample.get_field, name='_proposals'), # 定义了 set 方法,将实际维护的变量设置为 '_proposals',并在设置的时候检查类型是否是 dtype 定义的类型 InstanceData - fset=partial(BaseDataSample._set_field, name='_proposals', dtype=InstanceData), - fdel=partial(BaseDataSample._del_field, name='_proposals'), + fset=partial(BaseDataSample.set_field, name='_proposals', dtype=InstanceData), + fdel=partial(BaseDataSample.del_field, name='_proposals'), doc='Region proposals of an image' ) gt_instances = property( - fget=partial(BaseDataSample._get_field, name='_gt_instances'), - fset=partial(BaseDataSample._set_field, name='_gt_instances', dtype=InstanceData), - fdel=partial(BaseDataSample._del_field, name='_gt_instances'), + fget=partial(BaseDataSample.get_field, name='_gt_instances'), + fset=partial(BaseDataSample.set_field, name='_gt_instances', dtype=InstanceData), + fdel=partial(BaseDataSample.del_field, name='_gt_instances'), doc='Ground truth instances of an image' ) pred_instances = property( - fget=partial(BaseDataSample._get_field, name='_pred_instances'), - fset=partial(BaseDataSample._set_field, name='_pred_instances', dtype=InstanceData), -
fdel=partial(BaseDataSample._del_field, name='_pred_instances'), + fget=partial(BaseDataSample.get_field, name='_pred_instances'), + fset=partial(BaseDataSample.set_field, name='_pred_instances', dtype=InstanceData), + fdel=partial(BaseDataSample.del_field, name='_pred_instances'), doc='Predicted instances of an image' ) ``` diff --git a/mmengine/__init__.py b/mmengine/__init__.py index d389ac84..9e229e25 100644 --- a/mmengine/__init__.py +++ b/mmengine/__init__.py @@ -2,6 +2,7 @@ # flake8: noqa from .config import * from .dataset import * +from .data import * from .fileio import * from .registry import * from .utils import * diff --git a/mmengine/data/__init__.py b/mmengine/data/__init__.py index 1c6b205a..a3f1c537 100644 --- a/mmengine/data/__init__.py +++ b/mmengine/data/__init__.py @@ -1,4 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .base_data_element import BaseDataElement +from .base_data_sample import BaseDataSample from .sampler import DefaultSampler, InfiniteSampler -__all__ = ['DefaultSampler', 'InfiniteSampler'] +__all__ = [ + 'BaseDataElement', 'BaseDataSample', 'DefaultSampler', 'InfiniteSampler' +] diff --git a/mmengine/data/base_data_element.py b/mmengine/data/base_data_element.py new file mode 100644 index 00000000..5ab96f6f --- /dev/null +++ b/mmengine/data/base_data_element.py @@ -0,0 +1,455 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import Any, Iterator, Optional, Tuple + +import numpy as np +import torch + + +class BaseDataElement: + """A base data structure interface of OpenMMlab. + + Data elements refer to predicted results or ground truth labels on a + task, such as predicted bboxes, instance masks, semantic + segmentation masks, etc. 
Because groundtruth labels and predicted results + often have similar properties (for example, the predicted bboxes and the + groundtruth bboxes), MMEngine uses the same abstract data interface to + encapsulate predicted results and groundtruth labels, and it is recommended + to use different name conventions to distinguish them, such as using + ``gt_instances`` and ``pred_instances`` to distinguish between labels and + predicted results. Additionally, we distinguish data elements at instance + level, pixel level, and label level. Each of these types has its own + characteristics. Therefore, MMEngine defines the base class + ``BaseDataElement``, and implements ``InstanceData``, ``PixelData``, and + ``LabelData`` inheriting from ``BaseDataElement`` to represent different + types of ground truth labels or predictions. + They are used as interfaces between different components. + + + The attributes in ``BaseDataElement`` are divided into two parts, + the ``metainfo`` and the ``data`` respectively. + + - ``metainfo``: Usually contains the + information about the image such as filename, + image_shape, pad_shape, etc. The attributes can be accessed or + modified by dict-like or object-like operations, such as + ``.``(for data access and modification) , ``in``, ``del``, + ``pop(str)``, ``get(str)``, ``metainfo_keys()``, + ``metainfo_values()``, ``metainfo_items()``, ``set_metainfo()``(for + set or change key-value pairs in metainfo). + + - ``data``: Annotations or model predictions are + stored. The attributes can be accessed or modified by + dict-like or object-like operations, such as + ``.`` , ``in``, ``del``, ``pop(str)``, ``get(str)``, ``data_keys()``, + ``data_values()``, ``data_items()``.
Users can also apply tensor-like + methods to all :obj:`torch.Tensor` in the ``data_fields``, + such as ``.cuda()``, ``.cpu()``, ``.numpy()``, ``.to()``, + ``.to_tensor()`` and ``.detach()``. + + Args: + metainfo (dict, optional): A dict contains the meta information + of a single image, such as ``dict(img_shape=(512, 512, 3), + scale_factor=(1, 1, 1, 1))``. Defaults to None. + data (dict, optional): A dict contains annotations of a single image or + model predictions. Defaults to None. + + Examples: + >>> from mmengine.data import BaseDataElement + >>> gt_instances = BaseDataElement() + + >>> bboxes = torch.rand((5, 4)) + >>> scores = torch.rand((5,)) + >>> img_id = 0 + >>> img_shape = (800, 1333) + >>> gt_instances = BaseDataElement( + metainfo=dict(img_id=img_id, img_shape=img_shape), + data=dict(bboxes=bboxes, scores=scores)) + >>> gt_instances = BaseDataElement(dict(img_id=img_id, + img_shape=(H, W))) + # new + >>> gt_instances1 = gt_instances.new( + metainfo=dict(img_id=1, img_shape=(640, 640)), + data=dict(bboxes=torch.rand((5, 4)), + scores=torch.rand((5,)))) + >>> gt_instances2 = gt_instances1.new() + + # add and process property + >>> gt_instances = BaseDataElement() + >>> gt_instances.set_metainfo(dict(img_id=9, img_shape=(100, 100))) + >>> assert 'img_shape' in gt_instances.metainfo_keys() + >>> assert 'img_shape' in gt_instances + >>> assert 'img_shape' not in gt_instances.data_keys() + >>> assert 'img_shape' in gt_instances.keys() + >>> print(gt_instances.img_shape) + + >>> gt_instances.scores = torch.rand((5,)) + >>> assert 'scores' in gt_instances.data_keys() + >>> assert 'scores' in gt_instances + >>> assert 'scores' in gt_instances.keys() + >>> assert 'scores' not in gt_instances.metainfo_keys() + >>> print(gt_instances.scores) + + >>> gt_instances.bboxes = torch.rand((5, 4)) + >>> assert 'bboxes' in gt_instances.data_keys() + >>> assert 'bboxes' in gt_instances + >>> assert 'bboxes' in gt_instances.keys() + >>> assert 'bboxes' not in
gt_instances.metainfo_keys() + >>> print(gt_instances.bboxes) + + # delete and change property + >>> gt_instances = BaseDataElement( + metainfo=dict(img_id=0, img_shape=(640, 640)), + data=dict(bboxes=torch.rand((6, 4)), scores=torch.rand((6,)))) + >>> gt_instances.img_shape = (1280, 1280) + >>> gt_instances.img_shape # (1280, 1280) + >>> gt_instances.bboxes = gt_instances.bboxes * 2 + >>> gt_instances.get('img_shape', None) # (640, 640) + >>> gt_instances.get('bboxes', None) # 6x4 tensor + >>> del gt_instances.img_shape + >>> del gt_instances.bboxes + >>> assert 'img_shape' not in gt_instances + >>> assert 'bboxes' not in gt_instances + >>> gt_instances.pop('img_shape', None) # None + >>> gt_instances.pop('bboxes', None) # None + + # Tensor-like + >>> cuda_instances = gt_instances.cuda() + >>> cuda_instances = gt_instancess.to('cuda:0') + >>> cpu_instances = cuda_instances.cpu() + >>> cpu_instances = cuda_instances.to('cpu') + >>> fp16_instances = cuda_instances.to( + device=None, dtype=torch.float16, non_blocking=False, copy=False, + memory_format=torch.preserve_format) + >>> cpu_instances = cuda_instances.detach() + >>> np_instances = cpu_instances.numpy() + + # print + >>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3)) + >>> instance_data = BaseDataElement(metainfo=img_meta) + >>> instance_data.det_labels = torch.LongTensor([0, 1, 2, 3]) + >>> instance_data.det_scores = torch.Tensor([0.01, 0.1, 0.2, 0.3]) + >>> print(results) + <BaseDataElement( + META INFORMATION + img_shape: (800, 1196, 3) + pad_shape: (800, 1216, 3) + DATA FIELDS + shape of det_labels: torch.Size([4]) + shape of det_scores: torch.Size([4]) + ) at 0x7f84acd10f90> + """ + + def __init__(self, + metainfo: Optional[dict] = None, + data: Optional[dict] = None) -> None: + + self._metainfo_fields: set = set() + self._data_fields: set = set() + + if metainfo is not None: + self.set_metainfo(metainfo=metainfo) + if data is not None: + self.set_data(data) + + def 
set_metainfo(self, metainfo: dict) -> None: + """Set or change key-value pairs in ``metainfo_field`` by parameter + ``metainfo``. + + Args: + metainfo (dict): A dict contains the meta information + of image, such as ``img_shape``, ``scale_factor``, etc. + """ + assert isinstance( + metainfo, + dict), f'metainfo should be a ``dict`` but got {type(metainfo)}' + meta = copy.deepcopy(metainfo) + for k, v in meta.items(): + if k in self._data_fields: + raise AttributeError(f'`{k}` is used in data,' + 'which is immutable. If you want to' + 'change the key in data, please use' + 'set_data') + self._metainfo_fields.add(k) + self.__dict__[k] = v + + def set_data(self, data: dict) -> None: + """Set or change key-value pairs in ``data_field`` by parameter + ``data``. + + Args: + data (dict): A dict contains annotations of image or + model predictions. + """ + assert isinstance(data, + dict), f'meta should be a `dict` but got {data}' + for k, v in data.items(): + self.__setattr__(k, v) + + def new(self, + metainfo: dict = None, + data: dict = None) -> 'BaseDataElement': + """Return a new data element with same type. If ``metainfo`` and + ``data`` are None, the new data element will have same metainfo and + data. If metainfo or data is not None, the new result will overwrite it + with the input value. + + Args: + metainfo (dict, optional): A dict contains the meta information + of image, such as ``img_shape``, ``scale_factor``, etc. + Defaults to None. + data (dict, optional): A dict contains annotations of image or + model predictions. Defaults to None. + """ + new_data = self.__class__() + + if metainfo is not None: + new_data.set_metainfo(metainfo) + else: + new_data.set_metainfo(dict(self.metainfo_items())) + if data is not None: + new_data.set_data(data) + else: + new_data.set_data(dict(self.data_items())) + return new_data + + def data_keys(self) -> list: + """ + Returns: + list: Contains all keys in data_fields. 
+ """ + return list(self._data_fields) + + def metainfo_keys(self) -> list: + """ + Returns: + list: Contains all keys in metainfo_fields. + """ + return list(self._metainfo_fields) + + def data_values(self) -> list: + """ + Returns: + list: Contains all values in data. + """ + return [getattr(self, k) for k in self.data_keys()] + + def metainfo_values(self) -> list: + """ + Returns: + list: Contains all values in metainfo. + """ + return [getattr(self, k) for k in self.metainfo_keys()] + + def keys(self) -> list: + """ + Returns: + list: Contains all keys in metainfo and data. + """ + return self.metainfo_keys() + self.data_keys() + + def values(self) -> list: + """ + Returns: + list: Contains all values in metainfo and data. + """ + return self.metainfo_values() + self.data_values() + + def items(self) -> Iterator[Tuple[str, Any]]: + """ + Returns: + iterator: an iterator object whose element is (key, value) tuple + pairs for ``metainfo`` and ``data``. + """ + for k in self.keys(): + yield (k, getattr(self, k)) + + def data_items(self) -> Iterator[Tuple[str, Any]]: + """ + Returns: + iterator: an iterator object whose element is (key, value) tuple + pairs for ``data``. + """ + for k in self.data_keys(): + yield (k, getattr(self, k)) + + def metainfo_items(self) -> Iterator[Tuple[str, Any]]: + """ + Returns: + iterator: an iterator object whose element is (key, value) tuple + pairs for ``metainfo``. + """ + for k in self.metainfo_keys(): + yield (k, getattr(self, k)) + + def __setattr__(self, name: str, val: Any): + """setattr is only used to set data.""" + if name in ('_metainfo_fields', '_data_fields'): + if not hasattr(self, name): + super().__setattr__(name, val) + else: + raise AttributeError(f'{name} has been used as a ' + 'private attribute, which is immutable. ') + else: + if name in self._metainfo_fields: + raise AttributeError( + f'`{name}` is used in meta information.' 
+ 'if you want to change the key in metainfo, please use' + 'set_metainfo(dict(name=val))') + + self._data_fields.add(name) + super().__setattr__(name, val) + + def __delattr__(self, item: str): + if item in ('_metainfo_fields', '_data_fields'): + raise AttributeError(f'{item} has been used as a ' + 'private attribute, which is immutable. ') + super().__delattr__(item) + if item in self._metainfo_fields: + self._metainfo_fields.remove(item) + elif item in self._data_fields: + self._data_fields.remove(item) + + # dict-like methods + __setitem__ = __setattr__ + __delitem__ = __delattr__ + + def get(self, *args) -> Any: + """get property in data and metainfo as the same as python.""" + assert len(args) < 3, '``get`` get more than 2 arguments' + return self.__dict__.get(*args) + + def pop(self, *args) -> Any: + """pop property in data and metainfo as the same as python.""" + assert len(args) < 3, '``pop`` get more than 2 arguments' + name = args[0] + if name in self._metainfo_fields: + self._metainfo_fields.remove(args[0]) + return self.__dict__.pop(*args) + + elif name in self._data_fields: + self._data_fields.remove(args[0]) + return self.__dict__.pop(*args) + + # with default value + elif len(args) == 2: + return args[1] + else: + # don't just use 'self.__dict__.pop(*args)' for only popping key in + # metainfo or data + raise KeyError(f'{args[0]} is not contained in metainfo or data') + + def __contains__(self, item: str) -> bool: + return item in self._data_fields or \ + item in self._metainfo_fields + + # Tensor-like methods + def to(self, *args, **kwargs) -> 'BaseDataElement': + """Apply same name function to all tensors in data_fields.""" + new_data = self.new() + for k, v in self.data_items(): + if hasattr(v, 'to'): + v = v.to(*args, **kwargs) + data = {k: v} + new_data.set_data(data) + for k, v in self.metainfo_items(): + if hasattr(v, 'to'): + v = v.to(*args, **kwargs) + metainfo = {k: v} + new_data.set_metainfo(metainfo) + return new_data + + # Tensor-like 
methods + def cpu(self) -> 'BaseDataElement': + """Convert all tensors to CPU in metainfo and data.""" + new_data = self.new() + for k, v in self.data_items(): + if isinstance(v, torch.Tensor): + v = v.cpu() + data = {k: v} + new_data.set_data(data) + for k, v in self.metainfo_items(): + if isinstance(v, torch.Tensor): + v = v.cpu() + metainfo = {k: v} + new_data.set_metainfo(metainfo) + return new_data + + # Tensor-like methods + def cuda(self) -> 'BaseDataElement': + """Convert all tensors to GPU in metainfo and data.""" + new_data = self.new() + for k, v in self.data_items(): + if isinstance(v, torch.Tensor): + v = v.cuda() + data = {k: v} + new_data.set_data(data) + for k, v in self.metainfo_items(): + if isinstance(v, torch.Tensor): + v = v.cuda() + metainfo = {k: v} + new_data.set_metainfo(metainfo) + return new_data + + # Tensor-like methods + def detach(self) -> 'BaseDataElement': + """Detach all tensors in metainfo and data.""" + new_data = self.new() + for k, v in self.data_items(): + if isinstance(v, torch.Tensor): + v = v.detach() + data = {k: v} + new_data.set_data(data) + for k, v in self.metainfo_items(): + if isinstance(v, torch.Tensor): + v = v.detach() + metainfo = {k: v} + new_data.set_metainfo(metainfo) + return new_data + + # Tensor-like methods + def numpy(self) -> 'BaseDataElement': + """Convert all tensor to np.narray in metainfo and data.""" + new_data = self.new() + for k, v in self.data_items(): + if isinstance(v, torch.Tensor): + v = v.detach().cpu().numpy() + data = {k: v} + new_data.set_data(data) + for k, v in self.metainfo_items(): + if isinstance(v, torch.Tensor): + v = v.detach().cpu().numpy() + metainfo = {k: v} + new_data.set_metainfo(metainfo) + return new_data + + def to_tensor(self) -> 'BaseDataElement': + """Convert all np.narray to tensor in metainfo and data.""" + new_data = self.new() + for k, v in self.data_items(): + if isinstance(v, np.ndarray): + v = torch.from_numpy(v) + data = {k: v} + new_data.set_data(data) + for 
k, v in self.metainfo_items(): + if isinstance(v, np.ndarray): + v = torch.from_numpy(v) + metainfo = {k: v} + new_data.set_metainfo(metainfo) + return new_data + + def __repr__(self) -> str: + repr = '\n META INFORMATION \n' + for k, v in self.metainfo_items(): + if isinstance(v, (torch.Tensor, np.ndarray)): + repr += f'shape of {k}: {v.shape} \n' + else: + repr += f'{k}: {v} \n' + repr += '\n DATA FIELDS \n' + for k, v in self.data_items(): + if isinstance(v, (torch.Tensor, np.ndarray)): + repr += f'shape of {k}: {v.shape} \n' + else: + repr += f'{k}: {v} \n' + classname = self.__class__.__name__ + return f'<{classname}({repr}\n) at {hex(id(self))}>' diff --git a/mmengine/data/base_data_sample.py b/mmengine/data/base_data_sample.py new file mode 100644 index 00000000..4c6a2607 --- /dev/null +++ b/mmengine/data/base_data_sample.py @@ -0,0 +1,563 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import Any, Iterator, Optional, Tuple, Type + +import numpy as np +import torch + +from .base_data_element import BaseDataElement + + +class BaseDataSample: + """A base data structure interface of OpenMMlab. + + A sample data consists of input data (such as an image) and its annotations + and predictions. In general, an image can have multiple types of + annotations and/or predictions at the same time (for example, both + pixel-level semantic segmentation annotations and instance-level detection + bboxes annotations). To facilitate data access of multitask, MMEngine + defines ``BaseDataSample`` as the base class for sample data encapsulation. + **The attributes of ``BaseDataSample`` will be various types of data + elements**, and the codebases in OpenMMLab need to implement their own + xxxDataSample such as ClsDataSample, DetDataSample, SegDataSample based on + ``BaseDataSample`` to encapsulate all relevant data, as a data + interface between dataset, model, visualizer, and evaluator components. 
+ + The attributes in ``BaseDataSample`` are divided into two parts, + the ``metainfo`` and the ``data`` respectively. + + - ``metainfo``: Usually contains the + information about the image such as filename, + image_shape, pad_shape, etc. The attributes can be accessed or + modified by dict-like or object-like operations, such as + ``.``(only for data access) , ``in``, ``del``, ``pop(str)``, + ``get(str)``, ``metainfo_keys()``, ``metainfo_values()``, + ``metainfo_items()``, ``set_metainfo()``(for set or change value + in metainfo). + + - ``data``: Annotations or model predictions are + stored. The attributes can be accessed or modified by + dict-like or object-like operations, such as + ``.`` , ``in``, ``del``, ``pop(str)``, ``get(str)``, ``data_keys()``, + ``data_values()``, ``data_items()``. Users can also apply tensor-like + methods to all :obj:`torch.Tensor` in the ``data_fields``, + such as ``.cuda()``, ``.cpu()``, ``.numpy()``, ``.to()``, + ``.to_tensor()`` and ``.detach()``. + + Args: + metainfo (dict, optional): A dict contains the meta information + of a sample, such as ``dict(img_shape=(512, 512, 3), + scale_factor=(1, 1, 1, 1))``. Defaults to None. + data (dict, optional): A dict contains annotations of a sample or + model predictions. Defaults to None.
+ + Examples: + >>> from mmengine.data import BaseDataElement, BaseDataSample + >>> gt_instances = BaseDataSample() + + >>> bboxes = torch.rand((5, 4)) + >>> scores = torch.rand((5,)) + >>> img_id = 0 + >>> img_shape = (800, 1333) + >>> gt_instances = BaseDataElement( + metainfo=dict(img_id=img_id, img_shape=img_shape), + data=dict(bboxes=bboxes, scores=scores)) + >>> data = dict(gt_instances=gt_instances) + >>> sample = BaseDataSample( + metainfo=dict(img_id=img_id, img_shape=img_shape), + data=data) + >>> sample = BaseDataSample(dict(img_id=img_id, + img_shape=(H, W))) + # new + >>> data1 = dict(bboxes=torch.rand((5, 4)), + scores=torch.rand((5,))) + >>> metainfo1 = dict(img_id=1, img_shape=(640, 640)), + >>> gt_instances1 = BaseDataElement( + metainfo=metainfo1, + data=data1) + >>> sample1 = sample.new( + metainfo=metainfo1 + data=dict(gt_instances1=gt_instances1)), + + >>> gt_instances2 = gt_instances1.new() + + # property add and access + >>> sample = BaseDataSample() + >>> gt_instances = BaseDataElement( + metainfo=dict(img_id=9, img_shape=(100, 100)), + data=dict(bboxes=torch.rand((5, 4)), scores=torch.rand((5,))) + >>> sample.set_metainfo(dict(img_id=9, img_shape=(100, 100)) + >>> assert 'img_shape' in sample.metainfo_keys() + >>> assert 'img_shape' in sample + >>> assert 'img_shape' not in sample.data_keys() + >>> assert 'img_shape' in sample.keys() + >>> print(sample.img_shape) + + >>> gt_instances.gt_instances = gt_instances + >>> assert 'gt_instances' in sample.data_keys() + >>> assert 'gt_instances' in sample + >>> assert 'gt_instances' in sample.keys() + >>> assert 'gt_instances' not in sample.metainfo_keys() + >>> print(sample.gt_instances) + + >>> pred_instances = BaseDataElement( + metainfo=dict(img_id=9, img_shape=(100, 100)), + data=dict(bboxes=torch.rand((5, 4)), scores=torch.rand((5,)) + >>> sample.pred_instances = pred_instances + >>> assert 'pred_instances' in sample.data_keys() + >>> assert 'pred_instances' in sample + >>> assert 
'pred_instances' in sample.keys() + >>> assert 'pred_instances' not in sample.metainfo_keys() + >>> print(sample.pred_instances) + + # property delete and change + >>> metainfo=dict(img_id=0, img_shape=(640, 640) + >>> gt_instances = BaseDataElement( + metainfo=metainfo), + data=dict(bboxes=torch.rand((6, 4)), scores=torch.rand((6,)))) + >>> sample = BaseDataSample(metainfo=metainfo, + data=dict(gt_instances=gt_instances)) + >>> sample.img_shape = (1280, 1280) + >>> sample.img_shape # (1280, 1280) + >>> sample.gt_instances = gt_instances + >>> sample.get('img_shape', None) # (640, 640) + >>> sample.get('gt_instances', None) + >>> del sample.img_shape + >>> del sample.gt_instances + >>> assert 'img_shape' not in sample + >>> assert 'gt_instances' not in sample + >>> sample.pop('img_shape', None) # None + >>> sample.pop('gt_instances', None) # None + + # Tensor-like + >>> cuda_sample = gt_instasamplences.cuda() + >>> cuda_sample = gt_sample.to('cuda:0') + >>> cpu_sample = cuda_sample.cpu() + >>> cpu_sample = cuda_sample.to('cpu') + >>> fp16_sample = cuda_sample.to( + device=None, dtype=torch.float16, non_blocking=False, copy=False, + memory_format=torch.preserve_format) + >>> cpu_sample = cuda_sample.detach() + >>> np_sample = cpu_sample.numpy() + + # print + >>> metainfo = dict(img_shape=(800, 1196, 3)) + >>> gt_instances = BaseDataElement( + metainfo=metainfo, + data=dict(det_labels=torch.LongTensor([0, 1, 2, 3]))) + + >>> data = dict(gt_instances=gt_instances) + >>> sample = BaseDataSample(metainfo=metainfo, data=data) + >>> print(sample) + <BaseDataSample(' + META INFORMATION ' + img_shape: (800, 1196, 3) ' + DATA FIELDS ' + gt_instances:<BaseDataElement(' + META INFORMATION ' + img_shape: (800, 1196, 3) ' + DATA FIELDS ' + shape of det_labels: torch.Size([4]) ' + ) at 0x7f9705daecd0>' + ) at 0x7f981e41c550>' + + # inheritance + >>> class DetDataSample(BaseDataSample): + >>> proposals = property( + >>> fget=partial(BaseDataSample.get_field, name='_proposals'), + 
>>> fset=partial( + >>> BaseDataSample.set_field, + >>> name='_proposals', + >>> dtype=BaseDataElement), + >>> fdel=partial(BaseDataSample.del_field, name='_proposals'), + >>> doc='Region proposals of an image') + >>> gt_instances = property( + >>> fget=partial(BaseDataSample.get_field, + name='_gt_instances'), + >>> fset=partial( + >>> BaseDataSample.set_field, + >>> name='_gt_instances', + >>> dtype=BaseDataElement), + >>> fdel=partial(BaseDataSample.del_field, + name='_gt_instances'), + >>> doc='Ground truth instances of an image') + >>> pred_instances = property( + >>> fget=partial( + >>> BaseDataSample.get_field, name='_pred_instances'), + >>> fset=partial( + >>> BaseDataSample.set_field, + >>> name='_pred_instances', + >>> dtype=BaseDataElement), + >>> fdel=partial( + >>> BaseDataSample.del_field, name='_pred_instances'), + >>> doc='Predicted instances of an image') + + >>> det_sample = DetDataSample() + >>> proposals = BaseDataElement(data=dict(bboxes=torch.rand((5, 4)))) + >>> det_sample.proposals = proposals + >>> assert 'proposals' in det_sample + >>> assert det_sample.proposals == proposals + >>> del det_sample.proposals + >>> assert 'proposals' not in det_sample + >>> with self.assertRaises(AssertionError): + det_sample.proposals = torch.rand((5, 4)) + """ + + def __init__(self, + metainfo: Optional[dict] = None, + data: Optional[dict] = None) -> None: + + self._metainfo_fields: set = set() + self._data_fields: set = set() + + if metainfo is not None: + self.set_metainfo(metainfo=metainfo) + if data is not None: + self.set_data(data) + + def set_metainfo(self, metainfo: dict) -> None: + """Set or change key-value pairs in ``metainfo_field`` by parameter + ``metainfo``. + + Args: + metainfo (dict): A dict contains the meta information + of image, such as ``img_shape``, ``scale_factor``, etc. 
+ """ + assert isinstance( + metainfo, dict), f'meta should be a ``dict`` but got {metainfo}' + meta = copy.deepcopy(metainfo) + for k, v in meta.items(): + if k in self._data_fields: + raise AttributeError(f'`{k}` is used in data,' + 'which is immutable. If you want to' + 'change the key in data, please use' + 'set_data') + self.set_field(name=k, value=v, field_type='metainfo', dtype=None) + + def set_data(self, data: dict) -> None: + """Set or change key-value pairs in ``data_field`` by parameter + ``data``. + + Args: + data (dict): A dict contains annotations of image or + model predictions. Defaults to None. + """ + assert isinstance(data, + dict), f'meta should be a ``dict`` but got {data}' + for k, v in data.items(): + self.set_field(name=k, value=v, field_type='data', dtype=None) + + def new(self, + metainfo: Optional[dict] = None, + data: Optional[dict] = None) -> 'BaseDataSample': + """Return a new data element with same type. If ``metainfo`` and + ``data`` are None, the new data element will have same metainfo and + data. If metainfo or data is not None, the new results will overwrite + it with the input value. + + Args: + metainfo (dict, optional): A dict contains the meta information + of image. such as ``img_shape``, ``scale_factor``, etc. + Defaults to None. + data (dict, optional): A dict contains annotations of image or + model predictions. Defaults to None. + """ + new_data = self.__class__() + + if metainfo is not None: + new_data.set_metainfo(metainfo) + else: + new_data.set_metainfo(dict(self.metainfo_items())) + if data is not None: + new_data.set_data(data) + else: + new_data.set_data(dict(self.data_items())) + return new_data + + def data_keys(self) -> list: + """ + Returns: + list: Contains all keys in data_fields. + """ + return list(self._data_fields) + + def metainfo_keys(self) -> list: + """ + Returns: + list: Contains all keys in metainfo_fields. 
+ """ + return list(self._metainfo_fields) + + def data_values(self) -> list: + """ + Returns: + list: Contains all values in data. + """ + return [getattr(self, k) for k in self.data_keys()] + + def metainfo_values(self) -> list: + """ + Returns: + list: Contains all values in metainfo. + """ + return [getattr(self, k) for k in self.metainfo_keys()] + + def keys(self) -> list: + """ + Returns: + list: Contains all keys in metainfo and data. + """ + return self.metainfo_keys() + self.data_keys() + + def values(self) -> list: + """ + Returns: + list: Contains all values in metainfo and data. + """ + return self.metainfo_values() + self.data_values() + + def items(self) -> Iterator[Tuple[str, Any]]: + """ + Returns: + iterator: an iterator object whose element is (key, value) tuple + pairs for ``metainfo`` and ``data``. + """ + for k in self.keys(): + yield (k, getattr(self, k)) + + def data_items(self) -> Iterator[Tuple[str, Any]]: + """ + Returns: + iterator: an iterator object whose element is (key, value) tuple + pairs for ``data``. + """ + + for k in self.data_keys(): + yield (k, getattr(self, k)) + + def metainfo_items(self) -> Iterator[Tuple[str, Any]]: + """ + Returns: + iterator: an iterator object whose element is (key, value) tuple + pairs for ``metainfo``. + """ + for k in self.metainfo_keys(): + yield (k, getattr(self, k)) + + def __setattr__(self, name: str, value: Any): + """setattr is only used to set data.""" + if name in ('_metainfo_fields', '_data_fields'): + if not hasattr(self, name): + super().__setattr__(name, value) + else: + raise AttributeError( + f'{name} has been used as a ' + f'private attribute, which is immutable. ') + else: + if name in self._metainfo_fields: + raise AttributeError( + f'``{name}`` is used in meta information.' 
+ 'If you want to change the key in metainfo, please use' + 'set_metainfo(dict(name=val))') + + self.set_field( + name=name, value=value, field_type='data', dtype=None) + + def __delattr__(self, item: str): + if item in ('_metainfo_fields', '_data_fields'): + raise AttributeError(f'{item} has been used as a ' + 'private attribute, which is immutable. ') + super().__delattr__(item) + if item in self._metainfo_fields: + self._metainfo_fields.remove(item) + else: + self._data_fields.remove(item) + + # dict-like methods + __setitem__ = __setattr__ + __delitem__ = __delattr__ + + def get(self, *args) -> Any: + """get property in data and metainfo as the same as python.""" + assert len(args) < 3, f'``get`` get more than 2 arguments {args}' + return self.__dict__.get(*args) + + def pop(self, *args) -> Any: + """pop property in data and metainfo as the same as python.""" + assert len(args) < 3, '``pop`` get more than 2 arguments' + name = args[0] + if name in self._metainfo_fields: + self._metainfo_fields.remove(args[0]) + return self.__dict__.pop(*args) + + elif name in self._data_fields: + self._data_fields.remove(args[0]) + return self.__dict__.pop(*args) + + # with default value + elif len(args) == 2: + return args[1] + else: + # don't just use 'self.__dict__.pop(*args)' for only popping key in + # metainfo or data + raise KeyError(f'{args[0]} is not contained in metainfo or data') + + def __contains__(self, item: str) -> bool: + return item in self._data_fields or \ + item in self._metainfo_fields + + def get_field(self, name: str) -> Any: + """Special method for get union field, used as property.getter + functions.""" + return getattr(self, name) + + # It's must to keep the parameters order ``value``, ``name``, for + # ``partial(BaseDataSample.set_field, + # name='_proposals', dtype=BaseDataElement)`` + def set_field(self, + value: Any, + name: str, + dtype: Optional[Type] = None, + field_type: str = 'data') -> None: + """Special method for set union field, used as 
property.setter + functions.""" + assert field_type in ['metainfo', 'data'] + if dtype is not None: + assert isinstance( + value, + dtype), f'{value} should be a {dtype} but got {type(value)}' + + super().__setattr__(name, value) + if field_type == 'metainfo': + self._metainfo_fields.add(name) + else: + self._data_fields.add(name) + + def del_field(self, name: str) -> None: + """Special method for deleting union field, used as property.deleter + functions.""" + self.__delattr__(name) + + # Tensor-like methods + def to(self, *args, **kwargs) -> 'BaseDataSample': + """Apply same name function to all tensors in data_fields.""" + new_data = self.new() + for k, v in self.data_items(): + if hasattr(v, 'to'): + v = v.to(*args, **kwargs) + data = {k: v} + new_data.set_data(data) + for k, v in self.metainfo_items(): + if hasattr(v, 'to'): + v = v.to(*args, **kwargs) + metainfo = {k: v} + new_data.set_metainfo(metainfo) + return new_data + + # Tensor-like methods + def cpu(self) -> 'BaseDataSample': + """Convert all tensors to CPU in metainfo and data.""" + new_data = self.new() + for k, v in self.data_items(): + if isinstance(v, (torch.Tensor, BaseDataElement)): + v = v.cpu() + data = {k: v} + new_data.set_data(data) + for k, v in self.metainfo_items(): + if isinstance(v, (torch.Tensor, BaseDataElement)): + v = v.cpu() + metainfo = {k: v} + new_data.set_metainfo(metainfo) + return new_data + + # Tensor-like methods + def cuda(self) -> 'BaseDataSample': + """Convert all tensors to GPU in metainfo and data.""" + new_data = self.new() + for k, v in self.data_items(): + if isinstance(v, (torch.Tensor, BaseDataElement)): + v = v.cuda() + data = {k: v} + new_data.set_data(data) + + for k, v in self.metainfo_items(): + if isinstance(v, (torch.Tensor, BaseDataElement)): + v = v.cuda() + metainfo = {k: v} + new_data.set_metainfo(metainfo) + return new_data + + # Tensor-like methods + def detach(self) -> 'BaseDataSample': + """Detach all tensors in metainfo and data.""" + new_data = 
self.new() + for k, v in self.data_items(): + if isinstance(v, (torch.Tensor, BaseDataElement)): + v = v.detach() + data = {k: v} + new_data.set_data(data) + for k, v in self.metainfo_items(): + if isinstance(v, (torch.Tensor, BaseDataElement)): + v = v.detach() + metainfo = {k: v} + new_data.set_metainfo(metainfo) + return new_data + + # Tensor-like methods + def numpy(self) -> 'BaseDataSample': + """Convert all tensor to np.narray in metainfo and data.""" + new_data = self.new() + for k, v in self.data_items(): + if isinstance(v, (torch.Tensor, BaseDataElement)): + v = v.detach().cpu().numpy() + data = {k: v} + new_data.set_data(data) + for k, v in self.metainfo_items(): + if isinstance(v, (torch.Tensor, BaseDataElement)): + v = v.detach().cpu().numpy() + metainfo = {k: v} + new_data.set_metainfo(metainfo) + return new_data + + def to_tensor(self) -> 'BaseDataSample': + """Convert all np.narray to tensor in metainfo and data.""" + new_data = self.new() + for k, v in self.data_items(): + data = {} + if isinstance(v, np.ndarray): + v = torch.from_numpy(v) + data.update({k: v}) + elif isinstance(v, (BaseDataElement, BaseDataSample)): + v = v.to_tensor() + data.update({k: v}) + new_data.set_data(data) + for k, v in self.metainfo_items(): + data = {} + if isinstance(v, np.ndarray): + v = torch.from_numpy(v) + data.update({k: v}) + elif isinstance(v, (BaseDataElement, BaseDataSample)): + v = v.to_tensor() + data.update({k: v}) + new_data.set_metainfo(data) + return new_data + + def __repr__(self) -> str: + _repr = '\n META INFORMATION \n' + for k, v in self.metainfo_items(): + if isinstance(v, (torch.Tensor, np.ndarray)): + _repr += f'shape of {k}: {v.shape} \n' + elif isinstance(v, (BaseDataElement, BaseDataSample)): + _repr += f'{k}:{repr(v)}\n' + else: + _repr += f'{k}: {v} \n' + _repr += '\n DATA FIELDS \n' + for k, v in self.data_items(): + if isinstance(v, (torch.Tensor, np.ndarray)): + _repr += f'shape of {k}: {v.shape} \n' + elif isinstance(v, (BaseDataElement, 
BaseDataSample)): + _repr += f'{k}:{repr(v)}\n' + else: + _repr += f'{k}: {v} \n' + classname = self.__class__.__name__ + return f'<{classname}({_repr}\n) at {hex(id(self))}>' diff --git a/tests/test_data/test_data_element.py b/tests/test_data/test_data_element.py index 4bf8c722..7ca7667c 100644 --- a/tests/test_data/test_data_element.py +++ b/tests/test_data/test_data_element.py @@ -18,6 +18,13 @@ class TestBaseDataElement(TestCase): data = dict(bboxes=torch.rand((5, 4)), scores=torch.rand((5, ))) return metainfo, data + def is_equal(self, x, y): + assert type(x) == type(y) + if isinstance(x, (int, float, str, list, tuple, dict, set)): + return x == y + elif isinstance(x, (torch.Tensor, np.ndarray)): + return (x == y).all() + def check_key_value(self, instances, metainfo=None, data=None): # check the existence of keys in metainfo, data, and instances if metainfo: @@ -26,22 +33,22 @@ class TestBaseDataElement(TestCase): assert k in instances.keys() assert k in instances.metainfo_keys() assert k not in instances.data_keys() - assert instances.get(k) == v - assert getattr(instances, k) == v + assert self.is_equal(instances.get(k), v) + assert self.is_equal(getattr(instances, k), v) if data: for k, v in data.items(): assert k in instances assert k in instances.keys() assert k not in instances.metainfo_keys() assert k in instances.data_keys() - assert instances.get(k) == v - assert getattr(instances, k) == v + assert self.is_equal(instances.get(k), v) + assert self.is_equal(getattr(instances, k), v) def check_data_device(self, instances, device): - assert instances.device == device + # assert instances.device == device for v in instances.data_values(): if isinstance(v, torch.Tensor): - assert v.device == device + assert v.device == torch.device(device) def check_data_dtype(self, instances, dtype): for v in instances.data_values(): @@ -83,8 +90,12 @@ class TestBaseDataElement(TestCase): # test new() with no arguments new_instances = instances.new() assert 
type(new_instances) == type(instances) - assert id(new_instances.bboxes) != id(instances.bboxes) - assert id(new_instances.bboxes) != id(data['bboxes']) + # After deepcopy, the address of new data'element will be same as + # origin, but when change new data' element will not effect the origin + # element and will have new address + _, data = self.setup_data() + new_instances.set_data(data) + assert not self.is_equal(new_instances.bboxes, instances.bboxes) self.check_key_value(new_instances, metainfo, data) # test new() with arguments @@ -92,8 +103,13 @@ class TestBaseDataElement(TestCase): new_instances = instances.new(metainfo=metainfo, data=data) assert type(new_instances) == type(instances) assert id(new_instances.bboxes) != id(instances.bboxes) + _, new_data = self.setup_data() + new_instances.set_data(new_data) assert id(new_instances.bboxes) != id(data['bboxes']) - self.check_key_value(new_instances, metainfo, data) + self.check_key_value(new_instances, metainfo, new_data) + + metainfo, data = self.setup_data() + new_instances = instances.new(metainfo=metainfo) def test_set_metainfo(self): metainfo, _ = self.setup_data() @@ -107,6 +123,16 @@ class TestBaseDataElement(TestCase): instances.set_metainfo(new_metainfo) self.check_key_value(instances, metainfo=new_metainfo) + # test have the same key in data + _, data = self.setup_data() + instances = BaseDataElement(data=data) + _, data = self.setup_data() + with self.assertRaises(AttributeError): + instances.set_metainfo(data) + + with self.assertRaises(AssertionError): + instances.set_metainfo(123) + def test_set_data(self): metainfo, data = self.setup_data() instances = BaseDataElement() @@ -122,9 +148,18 @@ class TestBaseDataElement(TestCase): metainfo, data = self.setup_data() instances = BaseDataElement(metainfo, data) - with self.assertRaises(AssertionError): + with self.assertRaises(AttributeError): instances.img_shape = metainfo['img_shape'] + # test set '_metainfo_fields' or '_data_fields' + with 
self.assertRaises(AttributeError): + instances._metainfo_fields = 1 + with self.assertRaises(AttributeError): + instances._data_fields = 1 + + with self.assertRaises(AssertionError): + instances.set_data(123) + def test_delete_modify(self): metainfo, data = self.setup_data() instances = BaseDataElement(metainfo, data) @@ -137,29 +172,60 @@ class TestBaseDataElement(TestCase): instances.set_metainfo(new_metainfo) self.check_key_value(instances, new_metainfo, new_data) - assert instances.bboxes != data['bboxes'] - assert instances.scores != data['scores'] - assert instances.img_id != metainfo['img_id'] - assert instances.img_shape != metainfo['img_shape'] + assert not self.is_equal(instances.bboxes, data['bboxes']) + assert not self.is_equal(instances.scores, data['scores']) + assert not self.is_equal(instances.img_id, metainfo['img_id']) + assert not self.is_equal(instances.img_shape, metainfo['img_shape']) del instances.bboxes - assert instances.pop('scores', None) == new_data['scores'] + del instances.img_id + assert not self.is_equal(instances.pop('scores', None), data['scores']) with self.assertRaises(AttributeError): del instances.scores assert 'bboxes' not in instances assert 'scores' not in instances + assert 'img_id' not in instances assert instances.pop('bboxes', None) is None + # test pop not exist key without default + with self.assertRaises(KeyError): + instances.pop('bboxes') assert instances.pop('scores', 'abcdef') == 'abcdef' + assert instances.pop('img_id', None) is None + # test pop not exist key without default + with self.assertRaises(KeyError): + instances.pop('img_id') + assert instances.pop('img_shape') == new_metainfo['img_shape'] + + # test del '_metainfo_fields' or '_data_fields' + with self.assertRaises(AttributeError): + del instances._metainfo_fields + with self.assertRaises(AttributeError): + del instances._data_fields + @pytest.mark.skipif( not torch.cuda.is_available(), reason='GPU is required!') def test_cuda(self): metainfo, data = 
self.setup_data() - instances = BaseDataElement(metainfo, data) + instances = BaseDataElement(metainfo=metainfo, data=data) cuda_instances = instances.cuda() - self.check_data_device(instances, 'cuda:0') + self.check_data_device(cuda_instances, 'cuda:0') + + # here we further test to convert from cuda to cpu + cpu_instances = cuda_instances.cpu() + self.check_data_device(cpu_instances, 'cpu') + del cuda_instances + + cuda_instances = instances.to('cuda:0') + self.check_data_device(cuda_instances, 'cuda:0') + + _, data = self.setup_data() + instances = BaseDataElement(metainfo=data) + + cuda_instances = instances.cuda() + self.check_data_device(cuda_instances, 'cuda:0') # here we further test to convert from cuda to cpu cpu_instances = cuda_instances.cpu() @@ -175,9 +241,18 @@ class TestBaseDataElement(TestCase): self.check_data_device(instances, 'cpu') cpu_instances = instances.cpu() - assert cpu_instances.device == 'cpu' - assert cpu_instances.bboxes.device == 'cpu' - assert cpu_instances.scores.device == 'cpu' + # assert cpu_instances.device == 'cpu' + assert cpu_instances.bboxes.device == torch.device('cpu') + assert cpu_instances.scores.device == torch.device('cpu') + + _, data = self.setup_data() + instances = BaseDataElement(metainfo=data) + self.check_data_device(instances, 'cpu') + + cpu_instances = instances.cpu() + # assert cpu_instances.device == 'cpu' + assert cpu_instances.bboxes.device == torch.device('cpu') + assert cpu_instances.scores.device == torch.device('cpu') def test_numpy_tensor(self): metainfo, data = self.setup_data() @@ -186,19 +261,89 @@ class TestBaseDataElement(TestCase): np_instances = instances.numpy() self.check_data_dtype(np_instances, np.ndarray) - tensor_instances = instances.to_tensor() + tensor_instances = np_instances.to_tensor() self.check_data_dtype(tensor_instances, torch.Tensor) + _, data = self.setup_data() + instances = BaseDataElement(metainfo=data) + + np_instances = instances.numpy() + 
self.check_data_dtype(np_instances, np.ndarray) + + tensor_instances = np_instances.to_tensor() + self.check_data_dtype(tensor_instances, torch.Tensor) + + def check_requires_grad(self, instances): + for v in instances.data_values(): + if isinstance(v, torch.Tensor): + assert v.requires_grad is False + + def test_detach(self): + metainfo, data = self.setup_data() + instances = BaseDataElement(metainfo, data) + instances.detach() + self.check_requires_grad(instances) + + _, data = self.setup_data() + instances = BaseDataElement(metainfo=data) + instances.detach() + self.check_requires_grad(instances) + def test_repr(self): - metainfo = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3)) - instances = BaseDataElement(metainfo=metainfo) - instances.det_labels = torch.LongTensor([0, 1, 2, 3]) - instances.det_scores = torch.Tensor([0.01, 0.1, 0.2, 0.3]) - assert repr(instances) == ('<BaseDataElement(\n' - ' META INFORMATION\n' - 'img_shape: (800, 1196, 3)\n' - 'pad_shape: (800, 1216, 3)\n' - ' DATA FIELDS\n' - 'shape of det_labels: torch.Size([4])\n' - 'shape of det_scores: torch.Size([4])\n' - ') at 0x7f84acd10f90>') + metainfo = dict(img_shape=(800, 1196, 3)) + data = dict(det_labels=torch.LongTensor([0, 1, 2, 3])) + instances = BaseDataElement(metainfo=metainfo, data=data) + address = hex(id(instances)) + assert repr(instances) == (f'<BaseDataElement(' + f'\n META INFORMATION \n' + f'img_shape: (800, 1196, 3) \n' + f'\n DATA FIELDS \n' + f'shape of det_labels: torch.Size([4]) \n' + f'\n) at {address}>') + metainfo = dict(img_shape=(800, 1196, 3)) + data = dict(det_labels=torch.LongTensor([0, 1, 2, 3])) + instances = BaseDataElement(data=metainfo, metainfo=data) + address = hex(id(instances)) + assert repr(instances) == (f'<BaseDataElement(' + f'\n META INFORMATION \n' + f'shape of det_labels: torch.Size([4]) \n' + f'\n DATA FIELDS \n' + f'img_shape: (800, 1196, 3) \n' + f'\n) at {address}>') + + def test_values(self): + # test_metainfo_values + metainfo, data = 
self.setup_data() + instances = BaseDataElement(metainfo, data) + assert len(instances.metainfo_values()) == len(metainfo.values()) + # test_values + assert len( + instances.values()) == len(metainfo.values()) + len(data.values()) + + # test_data_values + assert len(instances.data_values()) == len(data.values()) + + def test_keys(self): + # test_metainfo_keys + metainfo, data = self.setup_data() + instances = BaseDataElement(metainfo, data) + assert len(instances.metainfo_keys()) == len(metainfo.keys()) + + # test_keys + assert len(instances.keys()) == len(data.keys()) + len(metainfo.keys()) + + # test_data_keys + assert len(instances.data_keys()) == len(data.keys()) + + def test_items(self): + # test_metainfo_items + metainfo, data = self.setup_data() + instances = BaseDataElement(metainfo, data) + assert len(dict(instances.metainfo_items())) == len( + dict(metainfo.items())) + # test_items + assert len(dict(instances.items())) == len(dict( + metainfo.items())) + len(dict(data.items())) + + # test_data_items + assert len(dict(instances.data_items())) == len(dict(data.items())) diff --git a/tests/test_data/test_data_sample.py b/tests/test_data/test_data_sample.py index 0d16073b..cc940073 100644 --- a/tests/test_data/test_data_sample.py +++ b/tests/test_data/test_data_sample.py @@ -23,6 +23,14 @@ class TestBaseDataSample(TestCase): data = dict(gt_instances=gt_instances, pred_instances=pred_instances) return metainfo, data + def is_equal(self, x, y): + assert type(x) == type(y) + if isinstance( + x, (int, float, str, list, tuple, dict, set, BaseDataElement)): + return x == y + elif isinstance(x, (torch.Tensor, np.ndarray)): + return (x == y).all() + def check_key_value(self, instances, metainfo=None, data=None): # check the existence of keys in metainfo, data, and instances if metainfo: @@ -31,32 +39,39 @@ class TestBaseDataSample(TestCase): assert k in instances.keys() assert k in instances.metainfo_keys() assert k not in instances.data_keys() - assert 
instances.get(k) == v - assert getattr(instances, k) == v + assert self.is_equal(instances.get(k), v) + assert self.is_equal(getattr(instances, k), v) if data: for k, v in data.items(): assert k in instances assert k in instances.keys() assert k not in instances.metainfo_keys() assert k in instances.data_keys() - assert instances.get(k) == v - assert getattr(instances, k) == v + assert self.is_equal(instances.get(k), v) + assert self.is_equal(getattr(instances, k), v) def check_data_device(self, instances, device): - assert instances.device == device + # assert instances.device == device for v in instances.data_values(): if isinstance(v, torch.Tensor): - assert v.device == device - elif isinstance(v, BaseDataElement): + assert v.device == torch.device(device) + elif isinstance(v, (BaseDataSample, BaseDataElement)): self.check_data_device(v, device) def check_data_dtype(self, instances, dtype): for v in instances.data_values(): if isinstance(v, (torch.Tensor, np.ndarray)): assert isinstance(v, dtype) - if isinstance(v, BaseDataElement): + if isinstance(v, (BaseDataSample, BaseDataElement)): self.check_data_dtype(v, dtype) + def check_requires_grad(self, instances): + for v in instances.data_values(): + if isinstance(v, torch.Tensor): + assert v.requires_grad is False + if isinstance(v, (BaseDataSample, BaseDataElement)): + self.check_requires_grad(v) + def test_init(self): # initialization with no data and metainfo metainfo, data = self.setup_data() @@ -92,17 +107,27 @@ class TestBaseDataSample(TestCase): # test new() with no arguments new_instances = instances.new() assert type(new_instances) == type(instances) - assert id(new_instances.data) != id(instances.data) - assert id(new_instances.bboxes) != id(data) + # After deepcopy, the address of new data'element will be same as + # origin, but when change new data' element will not effect the origin + # element and will have new address + _, data = self.setup_data() + new_instances.set_data(data) + assert not 
self.is_equal(new_instances.gt_instances, + instances.gt_instances) self.check_key_value(new_instances, metainfo, data) # test new() with arguments metainfo, data = self.setup_data() new_instances = instances.new(metainfo=metainfo, data=data) assert type(new_instances) == type(instances) - assert id(new_instances.data) != id(instances.data) - assert id(new_instances.data) != id(data) - self.check_key_value(new_instances, metainfo, data) + assert id(new_instances.gt_instances) != id(instances.gt_instances) + _, new_data = self.setup_data() + new_instances.set_data(new_data) + assert id(new_instances.gt_instances) != id(data['gt_instances']) + self.check_key_value(new_instances, metainfo, new_data) + + metainfo, data = self.setup_data() + new_instances = instances.new(metainfo=metainfo) def test_set_metainfo(self): metainfo, _ = self.setup_data() @@ -116,6 +141,17 @@ class TestBaseDataSample(TestCase): instances.set_metainfo(new_metainfo) self.check_key_value(instances, metainfo=new_metainfo) + # test have the same key in data + # TODO + _, data = self.setup_data() + instances = BaseDataSample(data=data) + _, data = self.setup_data() + with self.assertRaises(AttributeError): + instances.set_metainfo(data) + + with self.assertRaises(AssertionError): + instances.set_metainfo(123) + def test_set_data(self): metainfo, data = self.setup_data() instances = BaseDataSample() @@ -129,6 +165,21 @@ class TestBaseDataSample(TestCase): instances.img_id = metainfo['img_id'] self.check_key_value(instances, data=metainfo) + # test can not set metainfo with `.` + metainfo, data = self.setup_data() + instances = BaseDataSample(metainfo, data) + with self.assertRaises(AttributeError): + instances.img_shape = metainfo['img_shape'] + + # test set '_metainfo_fields' or '_data_fields' + with self.assertRaises(AttributeError): + instances._metainfo_fields = 1 + with self.assertRaises(AttributeError): + instances._data_fields = 1 + + with self.assertRaises(AssertionError): + 
instances.set_data(123) + def test_delete_modify(self): metainfo, data = self.setup_data() instances = BaseDataSample(metainfo, data) @@ -141,25 +192,40 @@ class TestBaseDataSample(TestCase): instances.set_metainfo(new_metainfo) self.check_key_value(instances, new_metainfo, new_data) - assert instances.gt_instances != data['gt_instances'] - assert instances.pred_instances != data['pred_instances'] - assert instances.img_id != metainfo['img_id'] - assert instances.img_shape != metainfo['img_shape'] + assert not self.is_equal(instances.gt_instances, data['gt_instances']) + assert not self.is_equal(instances.pred_instances, + data['pred_instances']) + assert not self.is_equal(instances.img_id, metainfo['img_id']) + assert not self.is_equal(instances.img_shape, metainfo['img_shape']) del instances.gt_instances - assert instances.pop('pred_instances', - None) == new_data['pred_instances'] - - # pred_instances has been deleted, - # instances does not have the pred_instances + del instances.img_id + assert not self.is_equal( + instances.pop('pred_instances', None), data['pred_instances']) with self.assertRaises(AttributeError): del instances.pred_instances assert 'gt_instances' not in instances assert 'pred_instances' not in instances + assert 'img_id' not in instances assert instances.pop('gt_instances', None) is None + # test pop not exist key without default + with self.assertRaises(KeyError): + instances.pop('gt_instances') assert instances.pop('pred_instances', 'abcdef') == 'abcdef' + assert instances.pop('img_id', None) is None + # test pop not exist key without default + with self.assertRaises(KeyError): + instances.pop('img_id') + assert instances.pop('img_shape') == new_metainfo['img_shape'] + + # test del '_metainfo_fields' or '_data_fields' + with self.assertRaises(AttributeError): + del instances._metainfo_fields + with self.assertRaises(AttributeError): + del instances._data_fields + @pytest.mark.skipif( not torch.cuda.is_available(), reason='GPU is 
required!') def test_cuda(self): @@ -167,7 +233,21 @@ class TestBaseDataSample(TestCase): instances = BaseDataSample(metainfo, data) cuda_instances = instances.cuda() - self.check_data_device(instances, 'cuda:0') + self.check_data_device(cuda_instances, 'cuda:0') + + # here we further test to convert from cuda to cpu + cpu_instances = cuda_instances.cpu() + self.check_data_device(cpu_instances, 'cpu') + del cuda_instances + + cuda_instances = instances.to('cuda:0') + self.check_data_device(cuda_instances, 'cuda:0') + + _, data = self.setup_data() + instances = BaseDataSample(metainfo=data) + + cuda_instances = instances.cuda() + self.check_data_device(cuda_instances, 'cuda:0') # here we further test to convert from cuda to cpu cpu_instances = cuda_instances.cpu() @@ -183,49 +263,105 @@ class TestBaseDataSample(TestCase): self.check_data_device(instances, 'cpu') cpu_instances = instances.cpu() - assert cpu_instances.device == 'cpu' - assert cpu_instances.bboxes.device == 'cpu' - assert cpu_instances.scores.device == 'cpu' + # assert cpu_instances.device == 'cpu' + self.check_data_device(cpu_instances, 'cpu') + + _, data = self.setup_data() + instances = BaseDataSample(metainfo=data) + self.check_data_device(instances, 'cpu') + + cpu_instances = instances.cpu() + # assert cpu_instances.device == 'cpu' + self.check_data_device(cpu_instances, 'cpu') def test_numpy_tensor(self): metainfo, data = self.setup_data() + data.update(bboxes=torch.rand((5, 4))) instances = BaseDataSample(metainfo, data) + np_instances = instances.numpy() + self.check_data_dtype(np_instances, np.ndarray) + + tensor_instances = np_instances.to_tensor() + self.check_data_dtype(tensor_instances, torch.Tensor) + + _, data = self.setup_data() + data.update(bboxes=torch.rand((5, 4))) + instances = BaseDataSample(metainfo=data) np_instances = instances.numpy() self.check_data_dtype(np_instances, np.ndarray) - tensor_instances = instances.to_tensor() + tensor_instances = np_instances.to_tensor() 
self.check_data_dtype(tensor_instances, torch.Tensor) + def test_detach(self): + metainfo, data = self.setup_data() + instances = BaseDataSample(metainfo, data) + instances.detach() + self.check_requires_grad(instances) + + _, data = self.setup_data() + instances = BaseDataSample(metainfo=data) + instances.detach() + self.check_requires_grad(instances) + def test_repr(self): - metainfo = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3)) + metainfo = dict(img_shape=(800, 1196, 3)) gt_instances = BaseDataElement( - data=dict( - det_labels=torch.LongTensor([0, 1, 2, 3], - det_scores=torch.Tensor( - [0.01, 0.1, 0.2, 0.3])))) + metainfo=metainfo, + data=dict(det_labels=torch.LongTensor([0, 1, 2, 3]))) + data = dict(gt_instances=gt_instances) - instances = BaseDataSample(metainfo=metainfo, data=data) - assert repr(instances) == ('<BaseDataSample(\n' - ' META INFORMATION\n' - 'img_shape: (800, 1196, 3)\n' - 'pad_shape: (800, 1216, 3)\n' - ' DATA FIELDS\n' - '\tgt_instances: <BaseDataElement(\n' - '\t META INFORMATION\n' - '\timg_shape: (800, 1196, 3)\n' - '\tpad_shape: (800, 1216, 3)\n' - '\t DATA FIELDS\n' - '\tshape of det_labels: torch.Size([4])\n' - '\tshape of det_scores: torch.Size([4])\n' - '\t) at 0x7f84acd10f90>' - ') at 0x7f84acd10f90>') + sample = BaseDataSample(metainfo=metainfo, data=data) + address = hex(id(sample)) + address_gt_instances = hex(id(sample.gt_instances)) + assert repr(sample) == (f'<BaseDataSample(' + f'\n META INFORMATION \n' + f'img_shape: (800, 1196, 3) \n' + f'\n DATA FIELDS \n' + f'gt_instances:<BaseDataElement(' + f'\n META INFORMATION \n' + f'img_shape: (800, 1196, 3) \n' + f'\n DATA FIELDS \n' + f'shape of det_labels: torch.Size([4]) \n' + f'\n) at {address_gt_instances}>\n' + f'\n) at {address}>') + + sample = BaseDataSample(data=metainfo, metainfo=data) + address = hex(id(sample)) + address_gt_instances = hex(id(sample.gt_instances)) + assert repr(sample) == (f'<BaseDataSample(' + f'\n META INFORMATION \n' + 
f'gt_instances:<BaseDataElement(' + f'\n META INFORMATION \n' + f'img_shape: (800, 1196, 3) \n' + f'\n DATA FIELDS \n' + f'shape of det_labels: torch.Size([4]) \n' + f'\n) at {address_gt_instances}>\n' + f'\n DATA FIELDS \n' + f'img_shape: (800, 1196, 3) \n' + f'\n) at {address}>') + metainfo = dict(bboxes=torch.rand((5, 4))) + sample = BaseDataSample(metainfo=metainfo) + address = hex(id(sample)) + assert repr(sample) == (f'<BaseDataSample(' + f'\n META INFORMATION \n' + f'shape of bboxes: torch.Size([5, 4]) \n' + f'\n DATA FIELDS \n' + f'\n) at {address}>') + sample = BaseDataSample(data=metainfo) + address = hex(id(sample)) + assert repr(sample) == (f'<BaseDataSample(' + f'\n META INFORMATION \n' + f'\n DATA FIELDS \n' + f'shape of bboxes: torch.Size([5, 4]) \n' + f'\n) at {address}>') def test_set_get_fields(self): metainfo, data = self.setup_data() instances = BaseDataSample(metainfo) for key, value in data.items(): - instances._set_field(value, key, BaseDataElement) + instances.set_field(name=key, value=value, dtype=BaseDataElement) self.check_key_value(instances, data=data) # test type check @@ -233,19 +369,20 @@ class TestBaseDataSample(TestCase): instances = BaseDataSample() for key, value in data.items(): with self.assertRaises(AssertionError): - instances._set_field(value, key, BaseDataSample) + instances.set_field( + name=key, value=value, dtype=BaseDataSample) def test_del_field(self): metainfo, data = self.setup_data() instances = BaseDataSample(metainfo) for key, value in data.items(): - instances._set_field(value, key, BaseDataElement) - instances._del_field('gt_instances') - instances._del_field('pred_instances') + instances.set_field(value=value, name=key, dtype=BaseDataElement) + instances.del_field('gt_instances') + instances.del_field('pred_instances') # gt_instance has been deleted, instances does not have the gt_instance with self.assertRaises(AttributeError): - instances._del_field('gt_instances') + instances.del_field('gt_instances') assert 
'gt_instances' not in instances assert 'pred_instances' not in instances @@ -253,30 +390,28 @@ class TestBaseDataSample(TestCase): class DetDataSample(BaseDataSample): proposals = property( - fget=partial(BaseDataSample._get_field, name='_proposals'), + fget=partial(BaseDataSample.get_field, name='_proposals'), fset=partial( - BaseDataSample._set_field, + BaseDataSample.set_field, name='_proposals', dtype=BaseDataElement), - fdel=partial(BaseDataSample._del_field, name='_proposals'), + fdel=partial(BaseDataSample.del_field, name='_proposals'), doc='Region proposals of an image') gt_instances = property( - fget=partial(BaseDataSample._get_field, name='_gt_instances'), + fget=partial(BaseDataSample.get_field, name='_gt_instances'), fset=partial( - BaseDataSample._set_field, + BaseDataSample.set_field, name='_gt_instances', dtype=BaseDataElement), - fdel=partial(BaseDataSample._del_field, name='_gt_instances'), + fdel=partial(BaseDataSample.del_field, name='_gt_instances'), doc='Ground truth instances of an image') pred_instances = property( - fget=partial( - BaseDataSample._get_field, name='_pred_instances'), + fget=partial(BaseDataSample.get_field, name='_pred_instances'), fset=partial( - BaseDataSample._set_field, + BaseDataSample.set_field, name='_pred_instances', dtype=BaseDataElement), - fdel=partial( - BaseDataSample._del_field, name='_pred_instances'), + fdel=partial(BaseDataSample.del_field, name='_pred_instances'), doc='Predicted instances of an image') det_sample = DetDataSample() @@ -296,3 +431,40 @@ class TestBaseDataSample(TestCase): # test the data whether meet the requirements with self.assertRaises(AssertionError): det_sample.proposals = torch.rand((5, 4)) + + def test_values(self): + # test_metainfo_values + metainfo, data = self.setup_data() + instances = BaseDataSample(metainfo, data) + assert len(instances.metainfo_values()) == len(metainfo.values()) + # test_values + assert len( + instances.values()) == len(metainfo.values()) + len(data.values()) 
+ + # test_data_values + assert len(instances.data_values()) == len(data.values()) + + def test_keys(self): + # test_metainfo_keys + metainfo, data = self.setup_data() + instances = BaseDataSample(metainfo, data) + assert len(instances.metainfo_keys()) == len(metainfo.keys()) + + # test_keys + assert len(instances.keys()) == len(data.keys()) + len(metainfo.keys()) + + # test_data_keys + assert len(instances.data_keys()) == len(data.keys()) + + def test_items(self): + # test_metainfo_items + metainfo, data = self.setup_data() + instances = BaseDataSample(metainfo, data) + assert len(dict(instances.metainfo_items())) == len( + dict(metainfo.items())) + # test_items + assert len(dict(instances.items())) == len(dict( + metainfo.items())) + len(dict(data.items())) + + # test_data_items + assert len(dict(instances.data_items())) == len(dict(data.items())) -- GitLab