From 7e246b6f65872b514a89a8bb06862777464f8af3 Mon Sep 17 00:00:00 2001 From: liukuikun <24622904+Harold-lkk@users.noreply.github.com> Date: Thu, 31 Mar 2022 18:21:45 +0800 Subject: [PATCH] [Enhancement] refactor base data element (#143) * [Enhancement] refactor base data elment * fix comment * fix comment * fix pop not existing key without error --- docs/zh_cn/tutorials/evaluator.md | 8 +- mmengine/data/__init__.py | 5 +- mmengine/data/base_data_element.py | 311 ++++++---- mmengine/data/base_data_sample.py | 558 ------------------ mmengine/data/utils.py | 8 +- mmengine/evaluator/base.py | 10 +- mmengine/evaluator/composed_evaluator.py | 10 +- mmengine/hooks/checkpoint_hook.py | 6 +- mmengine/hooks/empty_cache_hook.py | 8 +- mmengine/hooks/hook.py | 28 +- mmengine/hooks/iter_timer_hook.py | 10 +- mmengine/hooks/logger_hook.py | 6 +- mmengine/hooks/naive_visualization_hook.py | 10 +- mmengine/hooks/optimizer_hook.py | 6 +- mmengine/hooks/param_scheduler_hook.py | 6 +- mmengine/runner/loops.py | 22 +- mmengine/visualization/visualizer.py | 10 +- mmengine/visualization/writer.py | 62 +- tests/test_data/test_data_element.py | 302 ++++++---- tests/test_data/test_data_sample.py | 479 --------------- tests/test_evaluator/test_base_evaluator.py | 10 +- .../test_naive_visualization_hook.py | 22 +- tests/test_visualizer/test_visualizer.py | 6 +- 23 files changed, 501 insertions(+), 1402 deletions(-) delete mode 100644 mmengine/data/base_data_sample.py delete mode 100644 tests/test_data/test_data_sample.py diff --git a/docs/zh_cn/tutorials/evaluator.md b/docs/zh_cn/tutorials/evaluator.md index 695073d4..1a09f181 100644 --- a/docs/zh_cn/tutorials/evaluator.md +++ b/docs/zh_cn/tutorials/evaluator.md @@ -111,16 +111,16 @@ class Accuracy(BaseEvaluator): default_prefix = 'ACC' - def process(self, data_batch: Sequence[Tuple[Any, BaseDataSample]], - predictions: Sequence[BaseDataSample]): + def process(self, data_batch: Sequence[Tuple[Any, BaseDataElement]], + predictions: Sequence[BaseDataElement]): """Process one batch of data and predictions. The processed Results should be stored in `self.results`, which will be used to computed the metrics when all batches have been processed. Args: - data_batch (Sequence[Tuple[Any, BaseDataSample]]): A batch of data + data_batch (Sequence[Tuple[Any, BaseDataElement]]): A batch of data from the dataloader. - predictions (Sequence[BaseDataSample]): A batch of outputs from + predictions (Sequence[BaseDataElement]): A batch of outputs from the model. """ diff --git a/mmengine/data/__init__.py b/mmengine/data/__init__.py index 396b07b9..b867465c 100644 --- a/mmengine/data/__init__.py +++ b/mmengine/data/__init__.py @@ -1,10 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. from .base_data_element import BaseDataElement -from .base_data_sample import BaseDataSample from .sampler import DefaultSampler, InfiniteSampler from .utils import pseudo_collate, worker_init_fn __all__ = [ - 'BaseDataElement', 'BaseDataSample', 'DefaultSampler', 'InfiniteSampler', - 'worker_init_fn', 'pseudo_collate' + 'BaseDataElement', 'DefaultSampler', 'InfiniteSampler', 'worker_init_fn', + 'pseudo_collate' ] diff --git a/mmengine/data/base_data_element.py b/mmengine/data/base_data_element.py index 609ce538..47e8d715 100644 --- a/mmengine/data/base_data_element.py +++ b/mmengine/data/base_data_element.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy -from typing import Any, Iterator, Optional, Tuple +from typing import Any, Iterator, Optional, Tuple, Type import numpy as np import torch @@ -40,17 +40,17 @@ class BaseDataElement: - ``data``: Annotations or model predictions are stored. The attributes can be accessed or modified by dict-like or object-like operations, such as - ``.`` , ``in``, ``del``, ``pop(str)`` ``get(str)``, ``data_keys()``, - ``data_values()``, ``data_items()``. Users can also apply tensor-like + ``.`` , ``in``, ``del``, ``pop(str)`` ``get(str)``, ``keys()``, + ``values()``, ``items()``. Users can also apply tensor-like methods to all obj:``torch.Tensor`` in the ``data_fileds``, such as ``.cuda()``, ``.cpu()``, ``.numpy()``, , ``.to()`` - ``to_tensor()``, ``.detach()``, ``.numpy()`` + ``to_tensor()``, ``.detach()``. Args: - meta_info (dict, optional): A dict contains the meta information + metainfo (dict, optional): A dict contains the meta information of single image. such as ``dict(img_shape=(512, 512, 3), scale_factor=(1, 1, 1, 1))``. Defaults to None. - data (dict, optional): A dict contains annotations of single image or + kwargs (dict, optional): A dict contains annotations of single image or model predictions. Defaults to None. Examples: @@ -62,9 +62,10 @@ class BaseDataElement: >>> img_shape = (800, 1333) >>> gt_instances = BaseDataElement( ... metainfo=dict(img_id=img_id, img_shape=img_shape), - ... data=dict(bboxes=bboxes, scores=scores)) - >>> gt_instances = BaseDataElement(dict(img_id=img_id, - ... img_shape=(H, W))) + ... bboxes=bboxes, scores=scores) + >>> gt_instances = BaseDataElement( + ... metainfo=dict(img_id=img_id, + ... img_shape=(H, W))) >>> # new >>> gt_instances1 = gt_instance.new( @@ -78,26 +79,26 @@ class BaseDataElement: >>> gt_instances.set_metainfo(dict(img_id=9, img_shape=(100, 100)) >>> assert 'img_shape' in gt_instances.metainfo_keys() >>> assert 'img_shape' in gt_instances - >>> assert 'img_shape' not in gt_instances.data_keys() - >>> assert 'img_shape' in gt_instances.keys() + >>> assert 'img_shape' not in gt_instances.keys() + >>> assert 'img_shape' in gt_instances.all_keys() >>> print(gt_instances.img_shape) >>> gt_instances.scores = torch.rand((5,)) - >>> assert 'scores' in gt_instances.data_keys() - >>> assert 'scores' in gt_instances >>> assert 'scores' in gt_instances.keys() + >>> assert 'scores' in gt_instances + >>> assert 'scores' in gt_instances.all_keys() >>> assert 'scores' not in gt_instances.metainfo_keys() >>> print(gt_instances.scores) >>> gt_instances.bboxes = torch.rand((5, 4)) - >>> assert 'bboxes' in gt_instances.data_keys() - >>> assert 'bboxes' in gt_instances >>> assert 'bboxes' in gt_instances.keys() + >>> assert 'bboxes' in gt_instances + >>> assert 'bboxes' in gt_instances.all_keys() >>> assert 'bboxes' not in gt_instances.metainfo_keys() >>> print(gt_instances.bboxes) >>> # delete and change property >>> gt_instances = BaseDataElement( ... metainfo=dict(img_id=0, img_shape=(640, 640)), - ... data=dict(bboxes=torch.rand((6, 4)), scores=torch.rand((6,)))) + ... bboxes=torch.rand((6, 4)), scores=torch.rand((6,))) >>> gt_instances.img_shape = (1280, 1280) >>> gt_instances.img_shape # (1280, 1280) >>> gt_instances.bboxes = gt_instances.bboxes * 2 @@ -122,32 +123,75 @@ class BaseDataElement: >>> np_instances = cpu_instances.numpy() >>> # print - >>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3)) - >>> instance_data = BaseDataElement(metainfo=img_meta) - >>> instance_data.det_labels = torch.LongTensor([0, 1, 2, 3]) - >>> instance_data.det_scores = torch.Tensor([0.01, 0.1, 0.2, 0.3]) - >>> print(results) + >>> metainfo = dict(img_shape=(800, 1196, 3)) + >>> gt_instances = BaseDataElement( + >>> metainfo=metainfo, det_labels=torch.LongTensor([0, 1, 2, 3])) + >>> sample = BaseDataElement(metainfo=metainfo, + ... gt_instances=gt_instances) + >>> print(sample) <BaseDataElement( - META INFORMATION - img_shape: (800, 1196, 3) - pad_shape: (800, 1216, 3) - DATA FIELDS - shape of det_labels: torch.Size([4]) - shape of det_scores: torch.Size([4]) - ) at 0x7f84acd10f90> + META INFORMATION + img_shape: (800, 1196, 3) + DATA FIELDS + gt_instances: <BaseDataElement( + META INFORMATION + img_shape: (800, 1196, 3) + DATA FIELDS + det_labels: tensor([0, 1, 2, 3]) + ) at 0x7f0ec5eadc70> + ) at 0x7f0fea49e130> + + >>> # inheritance + >>> class DetDataSample(BaseDataElement): + ... @property + ... def proposals(self): + ... return self._proposals + ... @proposals.setter + ... def proposals(self, value): + ... self.set_field(value, '_proposals', dtype=BaseDataElement) + ... @proposals.deleter + ... def proposals(self): + ... del self._proposals + ... @property + ... def gt_instances(self): + ... return self._gt_instances + ... @gt_instances.setter + ... def gt_instances(self, value): + ... self.set_field(value, '_gt_instances', + ... dtype=BaseDataElement) + ... @gt_instances.deleter + ... def gt_instances(self): + ... del self._gt_instances + ... @property + ... def pred_instances(self): + ... return self._pred_instances + ... @pred_instances.setter + ... def pred_instances(self, value): + ... self.set_field(value,'_pred_instances', + ... dtype=BaseDataElement) + ... @pred_instances.deleter + ... def pred_instances(self): + ... del self._pred_instances + >>> det_sample = DetDataSample() + >>> proposals = BaseDataElement(bboxes=torch.rand((5, 4))) + >>> det_sample.proposals = proposals + >>> assert 'proposals' in det_sample + >>> assert det_sample.proposals == proposals + >>> del det_sample.proposals + >>> assert 'proposals' not in det_sample + >>> with self.assertRaises(AssertionError): + ... det_sample.proposals = torch.rand((5, 4)) """ - def __init__(self, - metainfo: Optional[dict] = None, - data: Optional[dict] = None) -> None: + def __init__(self, *, metainfo: Optional[dict] = None, **kwargs) -> None: self._metainfo_fields: set = set() self._data_fields: set = set() if metainfo is not None: self.set_metainfo(metainfo=metainfo) - if data is not None: - self.set_data(data) + if kwargs: + self.set_data(kwargs) def set_metainfo(self, metainfo: dict) -> None: """Set or change key-value pairs in ``metainfo_field`` by parameter @@ -167,8 +211,7 @@ class BaseDataElement: 'which is immutable. If you want to' 'change the key in data, please use' 'set_data') - self._metainfo_fields.add(k) - self.__dict__[k] = v + self.set_field(name=k, value=v, field_type='metainfo', dtype=None) def set_data(self, data: dict) -> None: """Set or change key-value pairs in ``data_field`` by parameter @@ -181,7 +224,21 @@ class BaseDataElement: assert isinstance(data, dict), f'meta should be a `dict` but got {data}' for k, v in data.items(): - self.__setattr__(k, v) + self.set_field(name=k, value=v, field_type='data', dtype=None) + + def update(self, instance: 'BaseDataElement') -> None: + """The update() method updates the BaseDataElement with the elements + from another BaseDataElement object. + + Args: + instance (BaseDataElement): Another BaseDataElement object for + update the current object. + """ + assert isinstance( + instance, BaseDataElement + ), f'instance should be a `BaseDataElement` but got {type(instance)}' + self.set_metainfo(dict(instance.metainfo_items())) + self.set_data(dict(instance.items())) def new(self, metainfo: dict = None, @@ -197,6 +254,8 @@ class BaseDataElement: Defaults to None. data (dict, optional): A dict contains annotations of image or model predictions. Defaults to None. + Returns: + BaseDataElement: a new data element with same type. """ new_data = self.__class__() @@ -207,10 +266,21 @@ class BaseDataElement: if data is not None: new_data.set_data(data) else: - new_data.set_data(dict(self.data_items())) + new_data.set_data(dict(self.items())) return new_data - def data_keys(self) -> list: + def clone(self): + """Deep copy the current data element. + + Returns: + BaseDataElement: the copy of current data element. + """ + clone_data = self.__class__() + clone_data.set_metainfo(dict(self.metainfo_items())) + clone_data.set_data(dict(self.items())) + return clone_data + + def keys(self) -> list: """ Returns: list: Contains all keys in data_fields. @@ -224,12 +294,12 @@ class BaseDataElement: """ return list(self._metainfo_fields) - def data_values(self) -> list: + def values(self) -> list: """ Returns: list: Contains all values in data. """ - return [getattr(self, k) for k in self.data_keys()] + return [getattr(self, k) for k in self.keys()] def metainfo_values(self) -> list: """ @@ -238,36 +308,36 @@ class BaseDataElement: """ return [getattr(self, k) for k in self.metainfo_keys()] - def keys(self) -> list: + def all_keys(self) -> list: """ Returns: list: Contains all keys in metainfo and data. """ - return self.metainfo_keys() + self.data_keys() + return self.metainfo_keys() + self.keys() - def values(self) -> list: + def all_values(self) -> list: """ Returns: list: Contains all values in metainfo and data. """ - return self.metainfo_values() + self.data_values() + return self.metainfo_values() + self.values() - def items(self) -> Iterator[Tuple[str, Any]]: + def all_items(self) -> Iterator[Tuple[str, Any]]: """ Returns: iterator: an iterator object whose element is (key, value) tuple pairs for ``metainfo`` and ``data``. """ - for k in self.keys(): + for k in self.all_keys(): yield (k, getattr(self, k)) - def data_items(self) -> Iterator[Tuple[str, Any]]: + def items(self) -> Iterator[Tuple[str, Any]]: """ Returns: iterator: an iterator object whose element is (key, value) tuple pairs for ``data``. """ - for k in self.data_keys(): + for k in self.keys(): yield (k, getattr(self, k)) def metainfo_items(self) -> Iterator[Tuple[str, Any]]: @@ -279,11 +349,11 @@ class BaseDataElement: for k in self.metainfo_keys(): yield (k, getattr(self, k)) - def __setattr__(self, name: str, val: Any): + def __setattr__(self, name: str, value: Any): """setattr is only used to set data.""" if name in ('_metainfo_fields', '_data_fields'): if not hasattr(self, name): - super().__setattr__(name, val) + super().__setattr__(name, value) else: raise AttributeError(f'{name} has been used as a ' 'private attribute, which is immutable. ') @@ -292,10 +362,10 @@ class BaseDataElement: raise AttributeError( f'`{name}` is used in meta information.' 'if you want to change the key in metainfo, please use' - 'set_metainfo(dict(name=val))') + '`set_metainfo(dict(name=value))`') - self._data_fields.add(name) - super().__setattr__(name, val) + self.set_field( + name=name, value=value, field_type='data', dtype=None) def __delattr__(self, item: str): if item in ('_metainfo_fields', '_data_fields'): @@ -311,10 +381,9 @@ class BaseDataElement: __setitem__ = __setattr__ __delitem__ = __delattr__ - def get(self, *args) -> Any: + def get(self, key, default=None) -> Any: """get property in data and metainfo as the same as python.""" - assert len(args) < 3, '``get`` get more than 2 arguments' - return self.__dict__.get(*args) + return self.__dict__.get(key, default) def pop(self, *args) -> Any: """pop property in data and metainfo as the same as python.""" @@ -340,113 +409,123 @@ class BaseDataElement: return item in self._data_fields or \ item in self._metainfo_fields + def set_field(self, + value: Any, + name: str, + dtype: Optional[Type] = None, + field_type: str = 'data') -> None: + """Special method for set union field, used as property.setter + functions.""" + assert field_type in ['metainfo', 'data'] + if dtype is not None: + assert isinstance( + value, + dtype), f'{value} should be a {dtype} but got {type(value)}' + + super().__setattr__(name, value) + if field_type == 'metainfo': + self._metainfo_fields.add(name) + else: + self._data_fields.add(name) + # Tensor-like methods def to(self, *args, **kwargs) -> 'BaseDataElement': """Apply same name function to all tensors in data_fields.""" new_data = self.new() - for k, v in self.data_items(): + for k, v in self.items(): if hasattr(v, 'to'): v = v.to(*args, **kwargs) data = {k: v} new_data.set_data(data) - for k, v in self.metainfo_items(): - if hasattr(v, 'to'): - v = v.to(*args, **kwargs) - metainfo = {k: v} - new_data.set_metainfo(metainfo) return new_data # Tensor-like methods def cpu(self) -> 'BaseDataElement': - """Convert all tensors to CPU in metainfo and data.""" + """Convert all tensors to CPU in data.""" new_data = self.new() - for k, v in self.data_items(): - if isinstance(v, torch.Tensor): + for k, v in self.items(): + if isinstance(v, (torch.Tensor, BaseDataElement)): v = v.cpu() data = {k: v} new_data.set_data(data) - for k, v in self.metainfo_items(): - if isinstance(v, torch.Tensor): - v = v.cpu() - metainfo = {k: v} - new_data.set_metainfo(metainfo) return new_data # Tensor-like methods def cuda(self) -> 'BaseDataElement': - """Convert all tensors to GPU in metainfo and data.""" + """Convert all tensors to GPU in data.""" new_data = self.new() - for k, v in self.data_items(): - if isinstance(v, torch.Tensor): + for k, v in self.items(): + if isinstance(v, (torch.Tensor, BaseDataElement)): v = v.cuda() data = {k: v} new_data.set_data(data) - for k, v in self.metainfo_items(): - if isinstance(v, torch.Tensor): - v = v.cuda() - metainfo = {k: v} - new_data.set_metainfo(metainfo) return new_data # Tensor-like methods def detach(self) -> 'BaseDataElement': - """Detach all tensors in metainfo and data.""" + """Detach all tensors in data.""" new_data = self.new() - for k, v in self.data_items(): - if isinstance(v, torch.Tensor): + for k, v in self.items(): + if isinstance(v, (torch.Tensor, BaseDataElement)): v = v.detach() data = {k: v} new_data.set_data(data) - for k, v in self.metainfo_items(): - if isinstance(v, torch.Tensor): - v = v.detach() - metainfo = {k: v} - new_data.set_metainfo(metainfo) return new_data # Tensor-like methods def numpy(self) -> 'BaseDataElement': - """Convert all tensor to np.narray in metainfo and data.""" + """Convert all tensor to np.narray in data.""" new_data = self.new() - for k, v in self.data_items(): - if isinstance(v, torch.Tensor): + for k, v in self.items(): + if isinstance(v, (torch.Tensor, BaseDataElement)): v = v.detach().cpu().numpy() data = {k: v} new_data.set_data(data) - for k, v in self.metainfo_items(): - if isinstance(v, torch.Tensor): - v = v.detach().cpu().numpy() - metainfo = {k: v} - new_data.set_metainfo(metainfo) return new_data def to_tensor(self) -> 'BaseDataElement': - """Convert all np.narray to tensor in metainfo and data.""" + """Convert all np.narray to tensor in data.""" new_data = self.new() - for k, v in self.data_items(): - if isinstance(v, np.ndarray): - v = torch.from_numpy(v) - data = {k: v} - new_data.set_data(data) - for k, v in self.metainfo_items(): + for k, v in self.items(): + data = {} if isinstance(v, np.ndarray): v = torch.from_numpy(v) - metainfo = {k: v} - new_data.set_metainfo(metainfo) + data[k] = v + elif isinstance(v, BaseDataElement): + v = v.to_tensor() + data[k] = v + new_data.set_data(data) return new_data def __repr__(self) -> str: - repr = '\n META INFORMATION \n' - for k, v in self.metainfo_items(): - if isinstance(v, (torch.Tensor, np.ndarray)): - repr += f'shape of {k}: {v.shape} \n' - else: - repr += f'{k}: {v} \n' - repr += '\n DATA FIELDS \n' - for k, v in self.data_items(): - if isinstance(v, (torch.Tensor, np.ndarray)): - repr += f'shape of {k}: {v.shape} \n' + + def _addindent(s_: str, num_spaces: int) -> str: + s = s_.split('\n') + # don't do anything for single-line stuff + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(num_spaces * ' ') + line for line in s] + s = '\n'.join(s) # type: ignore + s = first + '\n' + s # type: ignore + return s # type: ignore + + def dump(obj: Any) -> str: + _repr = '' + if isinstance(obj, dict): + for k, v in obj.items(): + _repr += f'\n{k}: {_addindent(dump(v), 4)}' + elif isinstance(obj, BaseDataElement): + _repr += '\n\n META INFORMATION' + metainfo_items = dict(obj.metainfo_items()) + _repr += _addindent(dump(metainfo_items), 4) + _repr += '\n\n DATA FIELDS' + items = dict(obj.items()) + _repr += _addindent(dump(items), 4) + classname = obj.__class__.__name__ + _repr = f'<{classname}({_repr}\n) at {hex(id(obj))}>' else: - repr += f'{k}: {v} \n' - classname = self.__class__.__name__ - return f'<{classname}({repr}\n) at {hex(id(self))}>' + _repr += repr(obj) + return _repr + + return dump(self) diff --git a/mmengine/data/base_data_sample.py b/mmengine/data/base_data_sample.py deleted file mode 100644 index 659a4a8b..00000000 --- a/mmengine/data/base_data_sample.py +++ /dev/null @@ -1,558 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import copy -from typing import Any, Iterator, Optional, Tuple, Type - -import numpy as np -import torch - -from .base_data_element import BaseDataElement - - -class BaseDataSample: - """A base data structure interface of OpenMMlab. - - A sample data consists of input data (such as an image) and its annotations - and predictions. In general, an image can have multiple types of - annotations and/or predictions at the same time (for example, both - pixel-level semantic segmentation annotations and instance-level detection - bboxes annotations). To facilitate data access of multitask, MMEngine - defines ``BaseDataSample`` as the base class for sample data encapsulation. - **The attributes of ``BaseDataSample`` will be various types of data - elements**, and the codebases in OpenMMLab need to implement their own - xxxDataSample such as ClsDataSample, DetDataSample, SegDataSample based on - ``BaseDataSample`` to encapsulate all relevant data, as a data - interface between dataset, model, visualizer, and evaluator components. - - These attributes in ``BaseDataElement`` are divided into two parts, - the ``metainfo`` and the ``data`` respectively. - - - ``metainfo``: Usually contains the - information about the image such as filename, - image_shape, pad_shape, etc. The attributes can be accessed or - modified by dict-like or object-like operations, such as - ``.``(only for data access) , ``in``, ``del``, ``pop(str)``, - ``get(str)``, ``metainfo_keys()``, ``metainfo_values()``, - ``metainfo_items()``, ``set_metainfo()``(for set or change value - in metainfo). - - - ``data``: Annotations or model predictions are - stored. The attributes can be accessed or modified by - dict-like or object-like operations, such as - ``.`` , ``in``, ``del``, ``pop(str)`` ``get(str)``, ``data_keys()``, - ``data_values()``, ``data_items()``. Users can also apply tensor-like - methods to all obj:``torch.Tensor`` in the ``data_fileds``, - such as ``.cuda()``, ``.cpu()``, ``.numpy()``, , ``.to()``, - ``to_tensor()``, ``.detach()``, ``.numpy()`` - - Args: - meta_info (dict, optional): A dict contains the meta information - of a sample. such as ``dict(img_shape=(512, 512, 3), - scale_factor=(1, 1, 1, 1))``. Defaults to None. - data (dict, optional): A dict contains annotations of a sample or - model predictions. Defaults to None. - - Examples: - >>> from mmengine.data import BaseDataElement, BaseDataSample - >>> gt_instances = BaseDataSample() - >>> bboxes = torch.rand((5, 4)) - >>> scores = torch.rand((5,)) - >>> img_id = 0 - >>> img_shape = (800, 1333) - >>> gt_instances = BaseDataElement( - ... metainfo=dict(img_id=img_id, img_shape=img_shape), - ... data=dict(bboxes=bboxes, scores=scores)) - >>> data = dict(gt_instances=gt_instances) - >>> sample = BaseDataSample( - ... metainfo=dict(img_id=img_id, img_shape=img_shape), - ... data=data) - >>> sample = BaseDataSample(dict(img_id=img_id, - ... img_shape=(H, W))) - - >>> # new - >>> data1 = dict(bboxes=torch.rand((5, 4)), - scores=torch.rand((5,))) - >>> metainfo1 = dict(img_id=1, img_shape=(640, 640)), - >>> gt_instances1 = BaseDataElement( - ... metainfo=metainfo1, - ... data=data1) - >>> sample1 = sample.new( - ... metainfo=metainfo1 - ... data=dict(gt_instances1=gt_instances1)), - >>> gt_instances2 = gt_instances1.new() - - >>> # property add and access - >>> sample = BaseDataSample() - >>> gt_instances = BaseDataElement( - ... metainfo=dict(img_id=9, img_shape=(100, 100)), - ... data=dict(bboxes=torch.rand((5, 4)), scores=torch.rand((5,))) - >>> sample.set_metainfo(dict(img_id=9, img_shape=(100, 100)) - >>> assert 'img_shape' in sample.metainfo_keys() - >>> assert 'img_shape' in sample - >>> assert 'img_shape' not in sample.data_keys() - >>> assert 'img_shape' in sample.keys() - >>> print(sample.img_shape) - >>> gt_instances.gt_instances = gt_instances - >>> assert 'gt_instances' in sample.data_keys() - >>> assert 'gt_instances' in sample - >>> assert 'gt_instances' in sample.keys() - >>> assert 'gt_instances' not in sample.metainfo_keys() - >>> print(sample.gt_instances) - >>> pred_instances = BaseDataElement( - ... metainfo=dict(img_id=9, img_shape=(100, 100)), - ... data=dict(bboxes=torch.rand((5, 4)), scores=torch.rand((5,)) - >>> sample.pred_instances = pred_instances - >>> assert 'pred_instances' in sample.data_keys() - >>> assert 'pred_instances' in sample - >>> assert 'pred_instances' in sample.keys() - >>> assert 'pred_instances' not in sample.metainfo_keys() - >>> print(sample.pred_instances) - - >>> # property delete and change - >>> metainfo=dict(img_id=0, img_shape=(640, 640) - >>> gt_instances = BaseDataElement( - ... metainfo=metainfo), - ... data=dict(bboxes=torch.rand((6, 4)), scores=torch.rand((6,)))) - >>> sample = BaseDataSample(metainfo=metainfo, - ... data=dict(gt_instances=gt_instances)) - >>> sample.img_shape = (1280, 1280) - >>> sample.img_shape # (1280, 1280) - >>> sample.gt_instances = gt_instances - >>> sample.get('img_shape', None) # (640, 640) - >>> sample.get('gt_instances', None) - >>> del sample.img_shape - >>> del sample.gt_instances - >>> assert 'img_shape' not in sample - >>> assert 'gt_instances' not in sample - >>> sample.pop('img_shape', None) # None - >>> sample.pop('gt_instances', None) # None - - >>> # Tensor-like - >>> cuda_sample = gt_instasamplences.cuda() - >>> cuda_sample = gt_sample.to('cuda:0') - >>> cpu_sample = cuda_sample.cpu() - >>> cpu_sample = cuda_sample.to('cpu') - >>> fp16_sample = cuda_sample.to( - ... device=None, dtype=torch.float16, non_blocking=False, copy=False, - ... memory_format=torch.preserve_format) - >>> cpu_sample = cuda_sample.detach() - >>> np_sample = cpu_sample.numpy() - - >>> # print - >>> metainfo = dict(img_shape=(800, 1196, 3)) - >>> gt_instances = BaseDataElement( - ... metainfo=metainfo, - ... data=dict(det_labels=torch.LongTensor([0, 1, 2, 3]))) - >>> data = dict(gt_instances=gt_instances) - >>> sample = BaseDataSample(metainfo=metainfo, data=data) - >>> print(sample) - <BaseDataSample(' - META INFORMATION ' - img_shape: (800, 1196, 3) ' - DATA FIELDS ' - gt_instances:<BaseDataElement(' - META INFORMATION ' - img_shape: (800, 1196, 3) ' - DATA FIELDS ' - shape of det_labels: torch.Size([4]) ' - ) at 0x7f9705daecd0>' - ) at 0x7f981e41c550>' - - >>> # inheritance - >>> class DetDataSample(BaseDataSample): - ... proposals = property( - ... fget=partial(BaseDataSample.get_field, name='_proposals'), - ... fset=partial( - ... BaseDataSample.set_field, - ... name='_proposals', - ... dtype=BaseDataElement), - ... fdel=partial(BaseDataSample.del_field, name='_proposals'), - ... doc='Region proposals of an image') - ... gt_instances = property( - ... fget=partial(BaseDataSample.get_field, - ... name='_gt_instances'), - ... fset=partial( - ... BaseDataSample.set_field, - ... name='_gt_instances', - ... dtype=BaseDataElement), - ... fdel=partial(BaseDataSample.del_field, - ... name='_gt_instances'), - ... doc='Ground truth instances of an image') - ... pred_instances = property( - ... fget=partial( - ... BaseDataSample.get_field, name='_pred_instances'), - ... fset=partial( - ... BaseDataSample.set_field, - ... name='_pred_instances', - ... dtype=BaseDataElement), - ... fdel=partial( - ... BaseDataSample.del_field, name='_pred_instances'), - ... doc='Predicted instances of an image') - >>> det_sample = DetDataSample() - >>> proposals = BaseDataElement(data=dict(bboxes=torch.rand((5, 4)))) - >>> det_sample.proposals = proposals - >>> assert 'proposals' in det_sample - >>> assert det_sample.proposals == proposals - >>> del det_sample.proposals - >>> assert 'proposals' not in det_sample - >>> with self.assertRaises(AssertionError): - ... det_sample.proposals = torch.rand((5, 4)) - """ - - def __init__(self, - metainfo: Optional[dict] = None, - data: Optional[dict] = None) -> None: - - self._metainfo_fields: set = set() - self._data_fields: set = set() - - if metainfo is not None: - self.set_metainfo(metainfo=metainfo) - if data is not None: - self.set_data(data) - - def set_metainfo(self, metainfo: dict) -> None: - """Set or change key-value pairs in ``metainfo_field`` by parameter - ``metainfo``. - - Args: - metainfo (dict): A dict contains the meta information - of image, such as ``img_shape``, ``scale_factor``, etc. - """ - assert isinstance( - metainfo, dict), f'meta should be a ``dict`` but got {metainfo}' - meta = copy.deepcopy(metainfo) - for k, v in meta.items(): - if k in self._data_fields: - raise AttributeError(f'`{k}` is used in data,' - 'which is immutable. If you want to' - 'change the key in data, please use' - 'set_data') - self.set_field(name=k, value=v, field_type='metainfo', dtype=None) - - def set_data(self, data: dict) -> None: - """Set or change key-value pairs in ``data_field`` by parameter - ``data``. - - Args: - data (dict): A dict contains annotations of image or - model predictions. Defaults to None. - """ - assert isinstance(data, - dict), f'meta should be a ``dict`` but got {data}' - for k, v in data.items(): - self.set_field(name=k, value=v, field_type='data', dtype=None) - - def new(self, - metainfo: Optional[dict] = None, - data: Optional[dict] = None) -> 'BaseDataSample': - """Return a new data element with same type. If ``metainfo`` and - ``data`` are None, the new data element will have same metainfo and - data. If metainfo or data is not None, the new results will overwrite - it with the input value. - - Args: - metainfo (dict, optional): A dict contains the meta information - of image. such as ``img_shape``, ``scale_factor``, etc. - Defaults to None. - data (dict, optional): A dict contains annotations of image or - model predictions. Defaults to None. - """ - new_data = self.__class__() - - if metainfo is not None: - new_data.set_metainfo(metainfo) - else: - new_data.set_metainfo(dict(self.metainfo_items())) - if data is not None: - new_data.set_data(data) - else: - new_data.set_data(dict(self.data_items())) - return new_data - - def data_keys(self) -> list: - """ - Returns: - list: Contains all keys in data_fields. - """ - return list(self._data_fields) - - def metainfo_keys(self) -> list: - """ - Returns: - list: Contains all keys in metainfo_fields. - """ - return list(self._metainfo_fields) - - def data_values(self) -> list: - """ - Returns: - list: Contains all values in data. - """ - return [getattr(self, k) for k in self.data_keys()] - - def metainfo_values(self) -> list: - """ - Returns: - list: Contains all values in metainfo. - """ - return [getattr(self, k) for k in self.metainfo_keys()] - - def keys(self) -> list: - """ - Returns: - list: Contains all keys in metainfo and data. - """ - return self.metainfo_keys() + self.data_keys() - - def values(self) -> list: - """ - Returns: - list: Contains all values in metainfo and data. - """ - return self.metainfo_values() + self.data_values() - - def items(self) -> Iterator[Tuple[str, Any]]: - """ - Returns: - iterator: an iterator object whose element is (key, value) tuple - pairs for ``metainfo`` and ``data``. - """ - for k in self.keys(): - yield (k, getattr(self, k)) - - def data_items(self) -> Iterator[Tuple[str, Any]]: - """ - Returns: - iterator: an iterator object whose element is (key, value) tuple - pairs for ``data``. - """ - - for k in self.data_keys(): - yield (k, getattr(self, k)) - - def metainfo_items(self) -> Iterator[Tuple[str, Any]]: - """ - Returns: - iterator: an iterator object whose element is (key, value) tuple - pairs for ``metainfo``. - """ - for k in self.metainfo_keys(): - yield (k, getattr(self, k)) - - def __setattr__(self, name: str, value: Any): - """setattr is only used to set data.""" - if name in ('_metainfo_fields', '_data_fields'): - if not hasattr(self, name): - super().__setattr__(name, value) - else: - raise AttributeError( - f'{name} has been used as a ' - f'private attribute, which is immutable. ') - else: - if name in self._metainfo_fields: - raise AttributeError( - f'``{name}`` is used in meta information.' - 'If you want to change the key in metainfo, please use' - 'set_metainfo(dict(name=val))') - - self.set_field( - name=name, value=value, field_type='data', dtype=None) - - def __delattr__(self, item: str): - if item in ('_metainfo_fields', '_data_fields'): - raise AttributeError(f'{item} has been used as a ' - 'private attribute, which is immutable. ') - super().__delattr__(item) - if item in self._metainfo_fields: - self._metainfo_fields.remove(item) - else: - self._data_fields.remove(item) - - # dict-like methods - __setitem__ = __setattr__ - __delitem__ = __delattr__ - - def get(self, *args) -> Any: - """get property in data and metainfo as the same as python.""" - assert len(args) < 3, f'``get`` get more than 2 arguments {args}' - return self.__dict__.get(*args) - - def pop(self, *args) -> Any: - """pop property in data and metainfo as the same as python.""" - assert len(args) < 3, '``pop`` get more than 2 arguments' - name = args[0] - if name in self._metainfo_fields: - self._metainfo_fields.remove(args[0]) - return self.__dict__.pop(*args) - - elif name in self._data_fields: - self._data_fields.remove(args[0]) - return self.__dict__.pop(*args) - - # with default value - elif len(args) == 2: - return args[1] - else: - # don't just use 'self.__dict__.pop(*args)' for only popping key in - # metainfo or data - raise KeyError(f'{args[0]} is not contained in metainfo or data') - - def __contains__(self, item: str) -> bool: - return item in self._data_fields or \ - item in self._metainfo_fields - - def get_field(self, name: str) -> Any: - """Special method for get union field, used as property.getter - functions.""" - return getattr(self, name) - - # It's must to keep the parameters order ``value``, ``name``, for - # ``partial(BaseDataSample.set_field, - # name='_proposals', dtype=BaseDataElement)`` - def set_field(self, - value: Any, - name: str, - dtype: Optional[Type] = None, - field_type: str = 'data') -> None: - """Special method for set union field, used as property.setter - functions.""" - assert field_type in ['metainfo', 'data'] - if dtype is not None: - assert isinstance( - value, - dtype), f'{value} should be a {dtype} but got {type(value)}' - - super().__setattr__(name, value) - if field_type == 'metainfo': - self._metainfo_fields.add(name) - else: - self._data_fields.add(name) - - def del_field(self, name: str) -> None: - """Special method for deleting union field, used as property.deleter - functions.""" - self.__delattr__(name) - - # Tensor-like methods - def to(self, *args, **kwargs) -> 'BaseDataSample': - """Apply same name function to all tensors in data_fields.""" - new_data = self.new() - for k, v in self.data_items(): - if hasattr(v, 'to'): - v = v.to(*args, **kwargs) - data = {k: v} - new_data.set_data(data) - for k, v in self.metainfo_items(): - if hasattr(v, 'to'): - v = v.to(*args, **kwargs) - metainfo = {k: v} - new_data.set_metainfo(metainfo) - return new_data - - # Tensor-like methods - def cpu(self) -> 'BaseDataSample': - """Convert all tensors to CPU in metainfo and data.""" - new_data = self.new() - for k, v in self.data_items(): - if isinstance(v, (torch.Tensor, BaseDataElement)): - v = v.cpu() - data = {k: v} - new_data.set_data(data) - for k, v in self.metainfo_items(): - if isinstance(v, (torch.Tensor, BaseDataElement)): - v = v.cpu() - metainfo = {k: v} - new_data.set_metainfo(metainfo) - return new_data - - # Tensor-like methods - def cuda(self) -> 'BaseDataSample': - """Convert all tensors to GPU in metainfo and data.""" - new_data = self.new() - for k, v in self.data_items(): - if isinstance(v, (torch.Tensor, BaseDataElement)): - v = v.cuda() - data = {k: v} - new_data.set_data(data) - - for k, v in self.metainfo_items(): - if isinstance(v, (torch.Tensor, BaseDataElement)): - v = v.cuda() - metainfo = {k: v} - new_data.set_metainfo(metainfo) - return new_data - - # Tensor-like methods - def detach(self) -> 'BaseDataSample': - """Detach all tensors in metainfo and data.""" - new_data = self.new() - for k, v in self.data_items(): - if isinstance(v, (torch.Tensor, BaseDataElement)): - v = v.detach() - data = {k: v} - new_data.set_data(data) - for k, v in self.metainfo_items(): - if isinstance(v, (torch.Tensor, BaseDataElement)): - v = v.detach() - metainfo = {k: v} - new_data.set_metainfo(metainfo) - return new_data - - # Tensor-like methods - def numpy(self) -> 'BaseDataSample': - """Convert all tensor to np.narray in metainfo and data.""" - new_data = self.new() - for k, v in self.data_items(): - if isinstance(v, (torch.Tensor, BaseDataElement)): - v = v.detach().cpu().numpy() - data = {k: v} - new_data.set_data(data) - for k, v in self.metainfo_items(): - if isinstance(v, (torch.Tensor, BaseDataElement)): - v = v.detach().cpu().numpy() - metainfo = {k: v} - new_data.set_metainfo(metainfo) - return new_data - - def to_tensor(self) -> 'BaseDataSample': - """Convert all np.narray to tensor in metainfo and data.""" - new_data = self.new() - for k, v in self.data_items(): - data = {} - if isinstance(v, np.ndarray): - v = torch.from_numpy(v) - data.update({k: v}) - elif isinstance(v, (BaseDataElement, BaseDataSample)): - v = v.to_tensor() - data.update({k: v}) - new_data.set_data(data) - for k, v in self.metainfo_items(): - data = {} - if isinstance(v, np.ndarray): - v = torch.from_numpy(v) - data.update({k: v}) - elif isinstance(v, (BaseDataElement, BaseDataSample)): - v = v.to_tensor() - data.update({k: v}) - new_data.set_metainfo(data) - return new_data - - def __repr__(self) -> str: - _repr = '\n META INFORMATION \n' - for k, v in self.metainfo_items(): - if isinstance(v, (torch.Tensor, np.ndarray)): - _repr += f'shape of {k}: {v.shape} \n' - elif isinstance(v, (BaseDataElement, BaseDataSample)): - _repr += f'{k}:{repr(v)}\n' - else: - _repr += f'{k}: {v} \n' - _repr += '\n DATA FIELDS \n' - for k, v in self.data_items(): - if isinstance(v, (torch.Tensor, np.ndarray)): - _repr += f'shape of {k}: {v.shape} \n' - elif isinstance(v, (BaseDataElement, BaseDataSample)): - _repr += f'{k}:{repr(v)}\n' - else: - _repr += f'{k}: {v} \n' - classname = self.__class__.__name__ - return f'<{classname}({_repr}\n) at {hex(id(self))}>' diff --git a/mmengine/data/utils.py b/mmengine/data/utils.py index b4edba99..0f569d39 100644 --- a/mmengine/data/utils.py +++ b/mmengine/data/utils.py @@ -5,9 +5,9 @@ from typing import Any, Sequence, Tuple import numpy as np import torch -from .base_data_sample import BaseDataSample +from .base_data_element import BaseDataElement -DATA_BATCH = Sequence[Tuple[Any, BaseDataSample]] +DATA_BATCH = Sequence[Tuple[Any, BaseDataElement]] def worker_init_fn(worker_id: int, num_workers: int, rank: int, @@ -36,10 +36,10 @@ def pseudo_collate(data_batch: DATA_BATCH) -> DATA_BATCH: nothing just returns ``data_batch``. Args: - data_batch (Sequence[Tuple[Any, BaseDataSample]]): Batch of data from + data_batch (Sequence[Tuple[Any, BaseDataElement]]): Batch of data from dataloader. Returns: - Sequence[Tuple[Any, BaseDataSample]]: Return input ``data_batch``. + Sequence[Tuple[Any, BaseDataElement]]: Return input ``data_batch``. """ return data_batch diff --git a/mmengine/evaluator/base.py b/mmengine/evaluator/base.py index 5354df90..2f31444b 100644 --- a/mmengine/evaluator/base.py +++ b/mmengine/evaluator/base.py @@ -3,7 +3,7 @@ import warnings from abc import ABCMeta, abstractmethod from typing import Any, List, Optional, Sequence, Tuple, Union -from mmengine.data import BaseDataSample +from mmengine.data import BaseDataElement from mmengine.dist import (broadcast_object_list, collect_results, is_main_process) @@ -51,16 +51,16 @@ class BaseEvaluator(metaclass=ABCMeta): self._dataset_meta = dataset_meta @abstractmethod - def process(self, data_batch: Sequence[Tuple[Any, BaseDataSample]], - predictions: Sequence[BaseDataSample]) -> None: + def process(self, data_batch: Sequence[Tuple[Any, BaseDataElement]], + predictions: Sequence[BaseDataElement]) -> None: """Process one batch of data samples and predictions. The processed results should be stored in ``self.results``, which will be used to compute the metrics when all batches have been processed. Args: - data_batch (Sequence[Tuple[Any, BaseDataSample]]): A batch of data + data_batch (Sequence[Tuple[Any, BaseDataElement]]): A batch of data from the dataloader. - predictions (Sequence[BaseDataSample]): A batch of outputs from + predictions (Sequence[BaseDataElement]): A batch of outputs from the model. """ diff --git a/mmengine/evaluator/composed_evaluator.py b/mmengine/evaluator/composed_evaluator.py index 7e76e731..b965b358 100644 --- a/mmengine/evaluator/composed_evaluator.py +++ b/mmengine/evaluator/composed_evaluator.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import Any, Optional, Sequence, Tuple, Union -from mmengine.data import BaseDataSample +from mmengine.data import BaseDataElement from .base import BaseEvaluator @@ -32,14 +32,14 @@ class ComposedEvaluator: for evaluator in self.evaluators: evaluator.dataset_meta = dataset_meta - def process(self, data_batch: Sequence[Tuple[Any, BaseDataSample]], - predictions: Sequence[BaseDataSample]): + def process(self, data_batch: Sequence[Tuple[Any, BaseDataElement]], + predictions: Sequence[BaseDataElement]): """Invoke process method of each wrapped evaluator. Args: - data_batch (Sequence[Tuple[Any, BaseDataSample]]): A batch of data + data_batch (Sequence[Tuple[Any, BaseDataElement]]): A batch of data from the dataloader. - predictions (Sequence[BaseDataSample]): A batch of outputs from + predictions (Sequence[BaseDataElement]): A batch of outputs from the model. """ diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py index 81e7f36d..a373bbb5 100644 --- a/mmengine/hooks/checkpoint_hook.py +++ b/mmengine/hooks/checkpoint_hook.py @@ -4,13 +4,13 @@ import warnings from pathlib import Path from typing import Any, Optional, Sequence, Tuple, Union -from mmengine.data import BaseDataSample +from mmengine.data import BaseDataElement from mmengine.dist import master_only from mmengine.fileio import FileClient from mmengine.registry import HOOKS from .hook import Hook -DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]] +DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]] @HOOKS.register_module() @@ -185,7 +185,7 @@ class CheckpointHook(Hook): Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the train loop. - data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data + data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data from dataloader. Defaults to None. outputs (dict, optional): Outputs from model. Defaults to None. diff --git a/mmengine/hooks/empty_cache_hook.py b/mmengine/hooks/empty_cache_hook.py index 45a1df11..c793f01b 100644 --- a/mmengine/hooks/empty_cache_hook.py +++ b/mmengine/hooks/empty_cache_hook.py @@ -3,11 +3,11 @@ from typing import Any, Optional, Sequence, Tuple, Union import torch -from mmengine.data import BaseDataSample +from mmengine.data import BaseDataElement from mmengine.registry import HOOKS from .hook import Hook -DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]] +DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]] @HOOKS.register_module() @@ -39,14 +39,14 @@ class EmptyCacheHook(Hook): batch_idx: int, data_batch: DATA_BATCH = None, outputs: Optional[Union[dict, - Sequence[BaseDataSample]]] = None, + Sequence[BaseDataElement]]] = None, mode: str = 'train') -> None: """Empty cache after an iteration. Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the loop. - data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data + data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data from dataloader. Defaults to None. outputs (dict or sequence, optional): Outputs from model. Defaults to None. diff --git a/mmengine/hooks/hook.py b/mmengine/hooks/hook.py index bafe5f31..84060334 100644 --- a/mmengine/hooks/hook.py +++ b/mmengine/hooks/hook.py @@ -1,9 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import Any, Optional, Sequence, Tuple, Union -from mmengine.data import BaseDataSample +from mmengine.data import BaseDataElement -DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]] +DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]] class Hook: @@ -174,7 +174,7 @@ class Hook: Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the train loop. - data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): + data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data from dataloader. Defaults to None. """ self._before_iter( @@ -190,7 +190,7 @@ class Hook: Args: runner (Runner): The runner of the validation process. batch_idx (int): The index of the current batch in the val loop. - data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): + data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data from dataloader. Defaults to None. """ self._before_iter( @@ -206,7 +206,7 @@ class Hook: Args: runner (Runner): The runner of the testing process. batch_idx (int): The index of the current batch in the test loop. - data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): + data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data from dataloader. Defaults to None. """ self._before_iter( @@ -223,7 +223,7 @@ class Hook: Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the train loop. - data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): + data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data from dataloader. Defaults to None. outputs (dict, optional): Outputs from model. Defaults to None. @@ -239,7 +239,7 @@ class Hook: runner, batch_idx: int, data_batch: DATA_BATCH = None, - outputs: Optional[Sequence[BaseDataSample]] = None) \ + outputs: Optional[Sequence[BaseDataElement]] = None) \ -> None: """All subclasses should override this method, if they need any operations after each validation iteration. @@ -247,7 +247,7 @@ class Hook: Args: runner (Runner): The runner of the validation process. batch_idx (int): The index of the current batch in the val loop. - data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): + data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data from dataloader. Defaults to None. outputs (dict or sequence, optional): Outputs from model. Defaults to None. @@ -264,14 +264,14 @@ class Hook: runner, batch_idx: int, data_batch: DATA_BATCH = None, - outputs: Optional[Sequence[BaseDataSample]] = None) -> None: + outputs: Optional[Sequence[BaseDataElement]] = None) -> None: """All subclasses should override this method, if they need any operations after each test iteration. Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the test loop. - data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): + data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data from dataloader. Defaults to None. outputs (dict, optional): Outputs from model. Defaults to None. @@ -317,7 +317,7 @@ class Hook: runner (Runner): The runner of the training, validation or testing process. batch_idx (int): The index of the current batch in the loop. - data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): + data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data from dataloader. Defaults to None. mode (str): Current mode of runner. Defaults to 'train'. """ @@ -327,7 +327,7 @@ class Hook: runner, batch_idx: int, data_batch: DATA_BATCH = None, - outputs: Optional[Union[Sequence[BaseDataSample], + outputs: Optional[Union[Sequence[BaseDataElement], dict]] = None, mode: str = 'train') -> None: """All subclasses should override this method, if they need any @@ -337,9 +337,9 @@ class Hook: runner (Runner): The runner of the training, validation or testing process. batch_idx (int): The index of the current batch in the loop. - data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): + data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data from dataloader. Defaults to None. - outputs (Sequence[BaseDataSample], optional): Outputs from model. + outputs (Sequence[BaseDataElement], optional): Outputs from model. Defaults to None. mode (str): Current mode of runner. Defaults to 'train'. """ diff --git a/mmengine/hooks/iter_timer_hook.py b/mmengine/hooks/iter_timer_hook.py index a18231bd..bf123cae 100644 --- a/mmengine/hooks/iter_timer_hook.py +++ b/mmengine/hooks/iter_timer_hook.py @@ -2,11 +2,11 @@ import time from typing import Any, Optional, Sequence, Tuple, Union -from mmengine.data import BaseDataSample +from mmengine.data import BaseDataElement from mmengine.registry import HOOKS from .hook import Hook -DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]] +DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]] @HOOKS.register_module() @@ -37,7 +37,7 @@ class IterTimerHook(Hook): Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the loop. - data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data + data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data from dataloader. Defaults to None. mode (str): Current mode of runner. Defaults to 'train'. """ @@ -50,14 +50,14 @@ class IterTimerHook(Hook): batch_idx: int, data_batch: DATA_BATCH = None, outputs: Optional[Union[dict, - Sequence[BaseDataSample]]] = None, + Sequence[BaseDataElement]]] = None, mode: str = 'train') -> None: """Logging time for a iteration and update the time flag. Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the loop. - data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data + data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data from dataloader. Defaults to None. outputs (dict or sequence, optional): Outputs from model. Defaults to None. diff --git a/mmengine/hooks/logger_hook.py b/mmengine/hooks/logger_hook.py index 492716c5..786fd311 100644 --- a/mmengine/hooks/logger_hook.py +++ b/mmengine/hooks/logger_hook.py @@ -9,14 +9,14 @@ from typing import Any, Optional, Sequence, Tuple, Union import torch -from mmengine.data import BaseDataSample +from mmengine.data import BaseDataElement from mmengine.dist import master_only from mmengine.fileio import FileClient from mmengine.hooks import Hook from mmengine.registry import HOOKS from mmengine.utils import is_tuple_of, scandir -DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]] +DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]] @HOOKS.register_module() @@ -183,7 +183,7 @@ class LoggerHook(Hook): Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the train loop. - data_batch (Sequence[BaseDataSample], optional): Data from + data_batch (Sequence[BaseDataElement], optional): Data from dataloader. Defaults to None. outputs (dict, optional): Outputs from model. Defaults to None. diff --git a/mmengine/hooks/naive_visualization_hook.py b/mmengine/hooks/naive_visualization_hook.py index a3a3b092..2e05fc59 100644 --- a/mmengine/hooks/naive_visualization_hook.py +++ b/mmengine/hooks/naive_visualization_hook.py @@ -5,7 +5,7 @@ from typing import Any, Optional, Sequence, Tuple import cv2 import numpy as np -from mmengine.data import BaseDataSample +from mmengine.data import BaseDataElement from mmengine.hooks import Hook from mmengine.registry import HOOKS from mmengine.utils.misc import tensor2imgs @@ -41,16 +41,16 @@ class NaiveVisualizationHook(Hook): self, runner, batch_idx: int, - data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None, - outputs: Optional[Sequence[BaseDataSample]] = None) -> None: + data_batch: Optional[Sequence[Tuple[Any, BaseDataElement]]] = None, + outputs: Optional[Sequence[BaseDataElement]] = None) -> None: """Show or Write the predicted results. Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the test loop. - data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data + data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data from dataloader. Defaults to None. - outputs (Sequence[BaseDataSample], optional): Outputs from model. + outputs (Sequence[BaseDataElement], optional): Outputs from model. Defaults to None. """ if self.every_n_iters(runner, self._interval): diff --git a/mmengine/hooks/optimizer_hook.py b/mmengine/hooks/optimizer_hook.py index 03b7d606..ff33b54a 100644 --- a/mmengine/hooks/optimizer_hook.py +++ b/mmengine/hooks/optimizer_hook.py @@ -6,11 +6,11 @@ import torch from torch.nn.parameter import Parameter from torch.nn.utils import clip_grad -from mmengine.data import BaseDataSample +from mmengine.data import BaseDataElement from mmengine.registry import HOOKS from .hook import Hook -DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]] +DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]] @HOOKS.register_module() @@ -77,7 +77,7 @@ class OptimizerHook(Hook): Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the train loop. - data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data + data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data from dataloader. In order to keep this interface consistent with other hooks, we keep ``data_batch`` here. Defaults to None. diff --git a/mmengine/hooks/param_scheduler_hook.py b/mmengine/hooks/param_scheduler_hook.py index f3531cd2..9522abcf 100644 --- a/mmengine/hooks/param_scheduler_hook.py +++ b/mmengine/hooks/param_scheduler_hook.py @@ -1,11 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import Any, Optional, Sequence, Tuple -from mmengine.data import BaseDataSample +from mmengine.data import BaseDataElement from mmengine.registry import HOOKS from .hook import Hook -DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]] +DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]] @HOOKS.register_module() @@ -25,7 +25,7 @@ class ParamSchedulerHook(Hook): Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the train loop. - data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data + data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data from dataloader. In order to keep this interface consistent with other hooks, we keep ``data_batch`` here. Defaults to None. diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index b725dfcc..9e04d640 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Sequence, Tuple, Union import torch from torch.utils.data import DataLoader -from mmengine.data import BaseDataSample +from mmengine.data import BaseDataElement from mmengine.evaluator import BaseEvaluator, build_evaluator from mmengine.registry import LOOPS from mmengine.utils import is_list_of @@ -63,11 +63,11 @@ class EpochBasedTrainLoop(BaseLoop): self.runner._epoch += 1 def run_iter(self, idx, - data_batch: Sequence[Tuple[Any, BaseDataSample]]) -> None: + data_batch: Sequence[Tuple[Any, BaseDataElement]]) -> None: """Iterate one min-batch. Args: - data_batch (Sequence[Tuple[Any, BaseDataSample]]): Batch of data + data_batch (Sequence[Tuple[Any, BaseDataElement]]): Batch of data from dataloader. """ self.runner.call_hook( @@ -131,11 +131,11 @@ class IterBasedTrainLoop(BaseLoop): self.runner.call_hook('after_train') def run_iter(self, data_batch: Sequence[Tuple[Any, - BaseDataSample]]) -> None: + BaseDataElement]]) -> None: """Iterate one mini-batch. Args: - data_batch (Sequence[Tuple[Any, BaseDataSample]]): Batch of data + data_batch (Sequence[Tuple[Any, BaseDataElement]]): Batch of data from dataloader. """ self.runner.call_hook( @@ -201,16 +201,16 @@ class ValLoop(BaseLoop): self.runner.call_hook('after_val') @torch.no_grad() - def run_iter(self, idx, data_batch: Sequence[Tuple[Any, BaseDataSample]]): + def run_iter(self, idx, data_batch: Sequence[Tuple[Any, BaseDataElement]]): """Iterate one mini-batch. Args: - data_batch (Sequence[Tuple[Any, BaseDataSample]]): Batch of data + data_batch (Sequence[Tuple[Any, BaseDataElement]]): Batch of data from dataloader. """ self.runner.call_hook( 'before_val_iter', batch_idx=idx, data_batch=data_batch) - # outputs should be sequence of BaseDataSample + # outputs should be sequence of BaseDataElement outputs = self.runner.model(data_batch) self.evaluator.process(data_batch, outputs) self.runner.call_hook( @@ -259,16 +259,16 @@ class TestLoop(BaseLoop): @torch.no_grad() def run_iter(self, idx, - data_batch: Sequence[Tuple[Any, BaseDataSample]]) -> None: + data_batch: Sequence[Tuple[Any, BaseDataElement]]) -> None: """Iterate one mini-batch. Args: - data_batch (Sequence[Tuple[Any, BaseDataSample]]): Batch of data + data_batch (Sequence[Tuple[Any, BaseDataElement]]): Batch of data from dataloader. """ self.runner.call_hook( 'before_test_iter', batch_idx=idx, data_batch=data_batch) - # predictions should be sequence of BaseDataSample + # predictions should be sequence of BaseDataElement predictions = self.runner.model(data_batch) self.evaluator.process(data_batch, predictions) self.runner.call_hook( diff --git a/mmengine/visualization/visualizer.py b/mmengine/visualization/visualizer.py index 03616a83..ae6ff113 100644 --- a/mmengine/visualization/visualizer.py +++ b/mmengine/visualization/visualizer.py @@ -12,7 +12,7 @@ from matplotlib.collections import (LineCollection, PatchCollection, from matplotlib.figure import Figure from matplotlib.patches import Circle -from mmengine.data import BaseDataSample +from mmengine.data import BaseDataElement from mmengine.registry import VISUALIZERS from .utils import (check_type, check_type_and_length, tensor2ndarray, value2list) @@ -113,8 +113,8 @@ class Visualizer: >>> pass >>> def draw(self, >>> image: Optional[np.ndarray] = None, - >>> gt_sample: Optional['BaseDataSample'] = None, - >>> pred_sample: Optional['BaseDataSample'] = None, + >>> gt_sample: Optional['BaseDataElement'] = None, + >>> pred_sample: Optional['BaseDataElement'] = None, >>> show_gt: bool = True, >>> show_pred: bool = True) -> None: >>> pass @@ -131,8 +131,8 @@ class Visualizer: def draw(self, image: Optional[np.ndarray] = None, - gt_sample: Optional['BaseDataSample'] = None, - pred_sample: Optional['BaseDataSample'] = None, + gt_sample: Optional['BaseDataElement'] = None, + pred_sample: Optional['BaseDataElement'] = None, draw_gt: bool = True, draw_pred: bool = True) -> None: pass diff --git a/mmengine/visualization/writer.py b/mmengine/visualization/writer.py index 39e622fb..c4f548eb 100644 --- a/mmengine/visualization/writer.py +++ b/mmengine/visualization/writer.py @@ -9,7 +9,7 @@ import cv2 import numpy as np import torch -from mmengine.data import BaseDataSample +from mmengine.data import BaseDataElement from mmengine.fileio import dump from mmengine.registry import VISUALIZERS, WRITERS from mmengine.utils import TORCH_VERSION, ManagerMixin @@ -90,8 +90,8 @@ class BaseWriter(metaclass=ABCMeta): def add_image(self, name: str, image: Optional[np.ndarray] = None, - gt_sample: Optional['BaseDataSample'] = None, - pred_sample: Optional['BaseDataSample'] = None, + gt_sample: Optional['BaseDataElement'] = None, + pred_sample: Optional['BaseDataElement'] = None, draw_gt: bool = True, draw_pred: bool = True, step: int = 0, @@ -102,10 +102,10 @@ class BaseWriter(metaclass=ABCMeta): name (str): The unique identifier for the image to save. image (np.ndarray, optional): The image to be saved. The format should be RGB. Default to None. - gt_sample (:obj:`BaseDataSample`, optional): The ground truth data + gt_sample (:obj:`BaseDataElement`, optional): The ground truth data structure of OpenMMlab. Default to None. - pred_sample (:obj:`BaseDataSample`, optional): The predicted result - data structure of OpenMMlab. Default to None. + pred_sample (:obj:`BaseDataElement`, optional): The predicted + result data structure of OpenMMlab. Default to None. draw_gt (bool): Whether to draw the ground truth. Default: True. draw_pred (bool): Whether to draw the predicted result. Default to True. @@ -228,8 +228,8 @@ class LocalWriter(BaseWriter): def add_image(self, name: str, image: Optional[np.ndarray] = None, - gt_sample: Optional['BaseDataSample'] = None, - pred_sample: Optional['BaseDataSample'] = None, + gt_sample: Optional['BaseDataElement'] = None, + pred_sample: Optional['BaseDataElement'] = None, draw_gt: bool = True, draw_pred: bool = True, step: int = 0, @@ -240,10 +240,10 @@ class LocalWriter(BaseWriter): name (str): The unique identifier for the image to save. image (np.ndarray, optional): The image to be saved. The format should be RGB. Default to None. - gt_sample (:obj:`BaseDataSample`, optional): The ground truth data + gt_sample (:obj:`BaseDataElement`, optional): The ground truth data structure of OpenMMlab. Default to None. - pred_sample (:obj:`BaseDataSample`, optional): The predicted result - data structure of OpenMMlab. Default to None. + pred_sample (:obj:`BaseDataElement`, optional): The predicted + result data structure of OpenMMlab. Default to None. draw_gt (bool): Whether to draw the ground truth. Default to True. draw_pred (bool): Whether to draw the predicted result. Default to True. @@ -411,8 +411,8 @@ class WandbWriter(BaseWriter): def add_image(self, name: str, image: Optional[np.ndarray] = None, - gt_sample: Optional['BaseDataSample'] = None, - pred_sample: Optional['BaseDataSample'] = None, + gt_sample: Optional['BaseDataElement'] = None, + pred_sample: Optional['BaseDataElement'] = None, draw_gt: bool = True, draw_pred: bool = True, step: int = 0, @@ -423,10 +423,10 @@ class WandbWriter(BaseWriter): name (str): The unique identifier for the image to save. image (np.ndarray, optional): The image to be saved. The format should be RGB. Default to None. - gt_sample (:obj:`BaseDataSample`, optional): The ground truth data + gt_sample (:obj:`BaseDataElement`, optional): The ground truth data structure of OpenMMlab. Default to None. - pred_sample (:obj:`BaseDataSample`, optional): The predicted result - data structure of OpenMMlab. Default to None. + pred_sample (:obj:`BaseDataElement`, optional): The predicted + result data structure of OpenMMlab. Default to None. draw_gt (bool): Whether to draw the ground truth. Default: True. draw_pred (bool): Whether to draw the predicted result. Default to True. @@ -475,8 +475,8 @@ class WandbWriter(BaseWriter): def add_image_to_wandb(self, name: str, image: np.ndarray, - gt_sample: Optional['BaseDataSample'] = None, - pred_sample: Optional['BaseDataSample'] = None, + gt_sample: Optional['BaseDataElement'] = None, + pred_sample: Optional['BaseDataElement'] = None, draw_gt: bool = True, draw_pred: bool = True, step: int = 0, @@ -487,10 +487,10 @@ class WandbWriter(BaseWriter): name (str): The unique identifier for the image to save. image (np.ndarray): The image to be saved. The format should be BGR. - gt_sample (:obj:`BaseDataSample`, optional): The ground truth data + gt_sample (:obj:`BaseDataElement`, optional): The ground truth data structure of OpenMMlab. Default to None. - pred_sample (:obj:`BaseDataSample`, optional): The predicted result - data structure of OpenMMlab. Default to None. + pred_sample (:obj:`BaseDataElement`, optional): The predicted + result data structure of OpenMMlab. Default to None. draw_gt (bool): Whether to draw the ground truth. Default to True. draw_pred (bool): Whether to draw the predicted result. Default to True. @@ -608,8 +608,8 @@ class TensorboardWriter(BaseWriter): def add_image(self, name: str, image: Optional[np.ndarray] = None, - gt_sample: Optional['BaseDataSample'] = None, - pred_sample: Optional['BaseDataSample'] = None, + gt_sample: Optional['BaseDataElement'] = None, + pred_sample: Optional['BaseDataElement'] = None, draw_gt: bool = True, draw_pred: bool = True, step: int = 0, @@ -620,10 +620,10 @@ class TensorboardWriter(BaseWriter): name (str): The unique identifier for the image to save. image (np.ndarray, optional): The image to be saved. The format should be RGB. Default to None. - gt_sample (:obj:`BaseDataSample`, optional): The ground truth data + gt_sample (:obj:`BaseDataElement`, optional): The ground truth data structure of OpenMMlab. Default to None. - pred_sample (:obj:`BaseDataSample`, optional): The predicted result - data structure of OpenMMlab. Default to None. + pred_sample (:obj:`BaseDataElement`, optional): The predicted + result data structure of OpenMMlab. Default to None. draw_gt (bool): Whether to draw the ground truth. Default to True. draw_pred (bool): Whether to draw the predicted result. Default to True. @@ -756,8 +756,8 @@ class ComposedWriter(ManagerMixin): def add_image(self, name: str, image: Optional[np.ndarray] = None, - gt_sample: Optional['BaseDataSample'] = None, - pred_sample: Optional['BaseDataSample'] = None, + gt_sample: Optional['BaseDataElement'] = None, + pred_sample: Optional['BaseDataElement'] = None, draw_gt: bool = True, draw_pred: bool = True, step: int = 0, @@ -768,10 +768,10 @@ class ComposedWriter(ManagerMixin): name (str): The unique identifier for the image to save. image (np.ndarray, optional): The image to be saved. The format should be RGB. Default to None. - gt_sample (:obj:`BaseDataSample`, optional): The ground truth data + gt_sample (:obj:`BaseDataElement`, optional): The ground truth data structure of OpenMMlab. Default to None. - pred_sample (:obj:`BaseDataSample`, optional): The predicted result - data structure of OpenMMlab. Default to None. + pred_sample (:obj:`BaseDataElement`, optional): The predicted + result data structure of OpenMMlab. Default to None. draw_gt (bool): Whether to draw the ground truth. Default to True. draw_pred (bool): Whether to draw the predicted result. Default to True. diff --git a/tests/test_data/test_data_element.py b/tests/test_data/test_data_element.py index 6d65b493..275b8a3c 100644 --- a/tests/test_data/test_data_element.py +++ b/tests/test_data/test_data_element.py @@ -15,12 +15,17 @@ class TestBaseDataElement(TestCase): metainfo = dict( img_id=random.randint(0, 100), img_shape=(random.randint(400, 600), random.randint(400, 600))) - data = dict(bboxes=torch.rand((5, 4)), scores=torch.rand((5, ))) + gt_instances = BaseDataElement( + bboxes=torch.rand((5, 4)), labels=torch.rand((5, ))) + pred_instances = BaseDataElement( + bboxes=torch.rand((5, 4)), scores=torch.rand((5, ))) + data = dict(gt_instances=gt_instances, pred_instances=pred_instances) return metainfo, data def is_equal(self, x, y): assert type(x) == type(y) - if isinstance(x, (int, float, str, list, tuple, dict, set)): + if isinstance( + x, (int, float, str, list, tuple, dict, set, BaseDataElement)): return x == y elif isinstance(x, (torch.Tensor, np.ndarray)): return (x == y).all() @@ -30,9 +35,9 @@ class TestBaseDataElement(TestCase): if metainfo: for k, v in metainfo.items(): assert k in instances - assert k in instances.keys() + assert k in instances.all_keys() assert k in instances.metainfo_keys() - assert k not in instances.data_keys() + assert k not in instances.keys() assert self.is_equal(instances.get(k), v) assert self.is_equal(getattr(instances, k), v) if data: @@ -40,20 +45,31 @@ class TestBaseDataElement(TestCase): assert k in instances assert k in instances.keys() assert k not in instances.metainfo_keys() - assert k in instances.data_keys() + assert k in instances.all_keys() assert self.is_equal(instances.get(k), v) assert self.is_equal(getattr(instances, k), v) def check_data_device(self, instances, device): # assert instances.device == device - for v in instances.data_values(): + for v in instances.values(): if isinstance(v, torch.Tensor): assert v.device == torch.device(device) + elif isinstance(v, BaseDataElement): + self.check_data_device(v, device) def check_data_dtype(self, instances, dtype): - for v in instances.data_values(): + for v in instances.values(): if isinstance(v, (torch.Tensor, np.ndarray)): assert isinstance(v, dtype) + if isinstance(v, BaseDataElement): + self.check_data_dtype(v, dtype) + + def check_requires_grad(self, instances): + for v in instances.values(): + if isinstance(v, torch.Tensor): + assert v.requires_grad is False + if isinstance(v, BaseDataElement): + self.check_requires_grad(v) def test_init(self): # initialization with no data and metainfo @@ -68,24 +84,19 @@ class TestBaseDataElement(TestCase): # initialization with kwargs metainfo, data = self.setup_data() - instances = BaseDataElement(metainfo=metainfo, data=data) - self.check_key_value(instances, metainfo, data) - - # initialization with args - metainfo, data = self.setup_data() - instances = BaseDataElement(metainfo, data) + instances = BaseDataElement(metainfo=metainfo, **data) self.check_key_value(instances, metainfo, data) # initialization with args metainfo, data = self.setup_data() instances = BaseDataElement(metainfo=metainfo) self.check_key_value(instances, metainfo) - instances = BaseDataElement(data=data) + instances = BaseDataElement(**data) self.check_key_value(instances, data=data) def test_new(self): metainfo, data = self.setup_data() - instances = BaseDataElement(metainfo=metainfo, data=data) + instances = BaseDataElement(metainfo=metainfo, **data) # test new() with no arguments new_instances = instances.new() @@ -95,22 +106,29 @@ class TestBaseDataElement(TestCase): # element and will have new address _, data = self.setup_data() new_instances.set_data(data) - assert not self.is_equal(new_instances.bboxes, instances.bboxes) + assert not self.is_equal(new_instances.gt_instances, + instances.gt_instances) self.check_key_value(new_instances, metainfo, data) # test new() with arguments metainfo, data = self.setup_data() new_instances = instances.new(metainfo=metainfo, data=data) assert type(new_instances) == type(instances) - assert id(new_instances.bboxes) != id(instances.bboxes) + assert id(new_instances.gt_instances) != id(instances.gt_instances) _, new_data = self.setup_data() new_instances.set_data(new_data) - assert id(new_instances.bboxes) != id(data['bboxes']) + assert id(new_instances.gt_instances) != id(data['gt_instances']) self.check_key_value(new_instances, metainfo, new_data) metainfo, data = self.setup_data() new_instances = instances.new(metainfo=metainfo) + def test_clone(self): + metainfo, data = self.setup_data() + instances = BaseDataElement(metainfo=metainfo, **data) + new_instances = instances.clone() + assert type(new_instances) == type(instances) + def test_set_metainfo(self): metainfo, _ = self.setup_data() instances = BaseDataElement() @@ -125,7 +143,7 @@ class TestBaseDataElement(TestCase): # test have the same key in data _, data = self.setup_data() - instances = BaseDataElement(data=data) + instances = BaseDataElement(**data) _, data = self.setup_data() with self.assertRaises(AttributeError): instances.set_metainfo(data) @@ -137,8 +155,8 @@ class TestBaseDataElement(TestCase): metainfo, data = self.setup_data() instances = BaseDataElement() - instances.bboxes = data['bboxes'] - instances.scores = data['scores'] + instances.gt_instances = data['gt_instances'] + instances.pred_instances = data['pred_instances'] self.check_key_value(instances, data=data) # a.xx only set data rather than metainfo @@ -147,7 +165,7 @@ class TestBaseDataElement(TestCase): self.check_key_value(instances, data=metainfo) metainfo, data = self.setup_data() - instances = BaseDataElement(metainfo, data) + instances = BaseDataElement(metainfo=metainfo, **data) with self.assertRaises(AttributeError): instances.img_shape = metainfo['img_shape'] @@ -160,10 +178,20 @@ class TestBaseDataElement(TestCase): with self.assertRaises(AssertionError): instances.set_data(123) + def test_update(self): + metainfo, data = self.setup_data() + instances = BaseDataElement(metainfo=metainfo, **data) + proposals = BaseDataElement( + bboxes=torch.rand((5, 4)), scores=torch.rand((5, ))) + new_instances = BaseDataElement(proposals=proposals) + instances.update(new_instances) + self.check_key_value(instances, metainfo, + data.update(dict(proposals=proposals))) + def test_delete_modify(self): random.seed(10) metainfo, data = self.setup_data() - instances = BaseDataElement(metainfo, data) + instances = BaseDataElement(metainfo=metainfo, **data) new_metainfo, new_data = self.setup_data() # avoid generating same metainfo, data @@ -171,38 +199,36 @@ class TestBaseDataElement(TestCase): if new_metainfo['img_id'] == metainfo['img_id'] or new_metainfo[ 'img_shape'] == metainfo['img_shape']: new_metainfo, new_data = self.setup_data() - elif self.is_equal(new_data['bboxes'], - data['bboxes']) or self.is_equal( - new_data['scores'], data['scores']): - new_metainfo, new_data = self.setup_data() else: break - instances.bboxes = new_data['bboxes'] - instances.scores = new_data['scores'] + instances.gt_instances = new_data['gt_instances'] + instances.pred_instances = new_data['pred_instances'] # a.xx only set data rather than metainfo instances.set_metainfo(new_metainfo) self.check_key_value(instances, new_metainfo, new_data) - assert not self.is_equal(instances.bboxes, data['bboxes']) - assert not self.is_equal(instances.scores, data['scores']) + assert not self.is_equal(instances.gt_instances, data['gt_instances']) + assert not self.is_equal(instances.pred_instances, + data['pred_instances']) assert not self.is_equal(instances.img_id, metainfo['img_id']) assert not self.is_equal(instances.img_shape, metainfo['img_shape']) - del instances.bboxes + del instances.gt_instances del instances.img_id - assert not self.is_equal(instances.pop('scores', None), data['scores']) + assert not self.is_equal( + instances.pop('pred_instances', None), data['pred_instances']) with self.assertRaises(AttributeError): - del instances.scores + del instances.pred_instances - assert 'bboxes' not in instances - assert 'scores' not in instances + assert 'gt_instances' not in instances + assert 'pred_instances' not in instances assert 'img_id' not in instances - assert instances.pop('bboxes', None) is None + assert instances.pop('gt_instances', None) is None # test pop not exist key without default with self.assertRaises(KeyError): - instances.pop('bboxes') - assert instances.pop('scores', 'abcdef') == 'abcdef' + instances.pop('gt_instances') + assert instances.pop('pred_instances', 'abcdef') == 'abcdef' assert instances.pop('img_id', None) is None # test pop not exist key without default @@ -220,21 +246,7 @@ class TestBaseDataElement(TestCase): not torch.cuda.is_available(), reason='GPU is required!') def test_cuda(self): metainfo, data = self.setup_data() - instances = BaseDataElement(metainfo=metainfo, data=data) - - cuda_instances = instances.cuda() - self.check_data_device(cuda_instances, 'cuda:0') - - # here we further test to convert from cuda to cpu - cpu_instances = cuda_instances.cpu() - self.check_data_device(cpu_instances, 'cpu') - del cuda_instances - - cuda_instances = instances.to('cuda:0') - self.check_data_device(cuda_instances, 'cuda:0') - - _, data = self.setup_data() - instances = BaseDataElement(metainfo=data) + instances = BaseDataElement(metainfo=metainfo, **data) cuda_instances = instances.cuda() self.check_data_device(cuda_instances, 'cuda:0') @@ -249,35 +261,17 @@ class TestBaseDataElement(TestCase): def test_cpu(self): metainfo, data = self.setup_data() - instances = BaseDataElement(metainfo, data) - self.check_data_device(instances, 'cpu') - - cpu_instances = instances.cpu() - # assert cpu_instances.device == 'cpu' - assert cpu_instances.bboxes.device == torch.device('cpu') - assert cpu_instances.scores.device == torch.device('cpu') - - _, data = self.setup_data() - instances = BaseDataElement(metainfo=data) + instances = BaseDataElement(metainfo=metainfo, **data) self.check_data_device(instances, 'cpu') cpu_instances = instances.cpu() # assert cpu_instances.device == 'cpu' - assert cpu_instances.bboxes.device == torch.device('cpu') - assert cpu_instances.scores.device == torch.device('cpu') + assert cpu_instances.gt_instances.bboxes.device == torch.device('cpu') + assert cpu_instances.gt_instances.labels.device == torch.device('cpu') def test_numpy_tensor(self): metainfo, data = self.setup_data() - instances = BaseDataElement(metainfo, data) - - np_instances = instances.numpy() - self.check_data_dtype(np_instances, np.ndarray) - - tensor_instances = np_instances.to_tensor() - self.check_data_dtype(tensor_instances, torch.Tensor) - - _, data = self.setup_data() - instances = BaseDataElement(metainfo=data) + instances = BaseDataElement(metainfo=metainfo, **data) np_instances = instances.numpy() self.check_data_dtype(np_instances, np.ndarray) @@ -285,77 +279,141 @@ class TestBaseDataElement(TestCase): tensor_instances = np_instances.to_tensor() self.check_data_dtype(tensor_instances, torch.Tensor) - def check_requires_grad(self, instances): - for v in instances.data_values(): - if isinstance(v, torch.Tensor): - assert v.requires_grad is False - def test_detach(self): metainfo, data = self.setup_data() - instances = BaseDataElement(metainfo, data) - instances.detach() - self.check_requires_grad(instances) - - _, data = self.setup_data() - instances = BaseDataElement(metainfo=data) + instances = BaseDataElement(metainfo=metainfo, **data) instances.detach() self.check_requires_grad(instances) def test_repr(self): metainfo = dict(img_shape=(800, 1196, 3)) - data = dict(det_labels=torch.LongTensor([0, 1, 2, 3])) - instances = BaseDataElement(metainfo=metainfo, data=data) - address = hex(id(instances)) - assert repr(instances) == (f'<BaseDataElement(' - f'\n META INFORMATION \n' - f'img_shape: (800, 1196, 3) \n' - f'\n DATA FIELDS \n' - f'shape of det_labels: torch.Size([4]) \n' - f'\n) at {address}>') - metainfo = dict(img_shape=(800, 1196, 3)) - data = dict(det_labels=torch.LongTensor([0, 1, 2, 3])) - instances = BaseDataElement(data=metainfo, metainfo=data) - address = hex(id(instances)) - assert repr(instances) == (f'<BaseDataElement(' - f'\n META INFORMATION \n' - f'shape of det_labels: torch.Size([4]) \n' - f'\n DATA FIELDS \n' - f'img_shape: (800, 1196, 3) \n' - f'\n) at {address}>') + gt_instances = BaseDataElement( + metainfo=metainfo, det_labels=torch.LongTensor([0, 1, 2, 3])) + sample = BaseDataElement(metainfo=metainfo, gt_instances=gt_instances) + address = hex(id(sample)) + address_gt_instances = hex(id(sample.gt_instances)) + assert repr(sample) == ( + '<BaseDataElement(\n\n' + ' META INFORMATION\n' + ' img_shape: (800, 1196, 3)\n\n' + ' DATA FIELDS\n' + ' gt_instances: <BaseDataElement(\n \n' + ' META INFORMATION\n' + ' img_shape: (800, 1196, 3)\n \n' + ' DATA FIELDS\n' + ' det_labels: tensor([0, 1, 2, 3])\n' + f' ) at {address_gt_instances}>\n' + f') at {address}>') + + def test_set_fields(self): + metainfo, data = self.setup_data() + instances = BaseDataElement(metainfo=metainfo) + for key, value in data.items(): + instances.set_field(name=key, value=value, dtype=BaseDataElement) + self.check_key_value(instances, data=data) + + # test type check + _, data = self.setup_data() + instances = BaseDataElement() + for key, value in data.items(): + with self.assertRaises(AssertionError): + instances.set_field(name=key, value=value, dtype=torch.Tensor) + + def test_inheritance(self): + + class DetDataSample(BaseDataElement): + + @property + def proposals(self): + return self._proposals + + @proposals.setter + def proposals(self, value): + self.set_field( + value=value, name='_proposals', dtype=BaseDataElement) + + @proposals.deleter + def proposals(self): + del self._proposals + + @property + def gt_instances(self): + return self._gt_instances + + @gt_instances.setter + def gt_instances(self, value): + self.set_field( + value=value, name='_gt_instances', dtype=BaseDataElement) + + @gt_instances.deleter + def gt_instances(self): + del self._gt_instances + + @property + def pred_instances(self): + return self._pred_instances + + @pred_instances.setter + def pred_instances(self, value): + self.set_field( + value=value, name='_pred_instances', dtype=BaseDataElement) + + @pred_instances.deleter + def pred_instances(self): + del self._pred_instances + + det_sample = DetDataSample() + + # test set + proposals = BaseDataElement(bboxes=torch.rand((5, 4))) + det_sample.proposals = proposals + assert 'proposals' in det_sample + + # test get + assert det_sample.proposals == proposals + + # test delete + del det_sample.proposals + assert 'proposals' not in det_sample + + # test the data whether meet the requirements + with self.assertRaises(AssertionError): + det_sample.proposals = torch.rand((5, 4)) def test_values(self): # test_metainfo_values metainfo, data = self.setup_data() - instances = BaseDataElement(metainfo, data) + instances = BaseDataElement(metainfo=metainfo, **data) assert len(instances.metainfo_values()) == len(metainfo.values()) - # test_values - assert len( - instances.values()) == len(metainfo.values()) + len(data.values()) + # test_all_values + assert len(instances.all_values()) == len(metainfo.values()) + len( + data.values()) - # test_data_values - assert len(instances.data_values()) == len(data.values()) + # test_values + assert len(instances.values()) == len(data.values()) def test_keys(self): # test_metainfo_keys metainfo, data = self.setup_data() - instances = BaseDataElement(metainfo, data) + instances = BaseDataElement(metainfo=metainfo, **data) assert len(instances.metainfo_keys()) == len(metainfo.keys()) - # test_keys - assert len(instances.keys()) == len(data.keys()) + len(metainfo.keys()) + # test_all_keys + assert len( + instances.all_keys()) == len(data.keys()) + len(metainfo.keys()) - # test_data_keys - assert len(instances.data_keys()) == len(data.keys()) + # test_keys + assert len(instances.keys()) == len(data.keys()) def test_items(self): # test_metainfo_items metainfo, data = self.setup_data() - instances = BaseDataElement(metainfo, data) + instances = BaseDataElement(metainfo=metainfo, **data) assert len(dict(instances.metainfo_items())) == len( dict(metainfo.items())) - # test_items - assert len(dict(instances.items())) == len(dict( + # test_all_items + assert len(dict(instances.all_items())) == len(dict( metainfo.items())) + len(dict(data.items())) - # test_data_items - assert len(dict(instances.data_items())) == len(dict(data.items())) + # test_items + assert len(dict(instances.items())) == len(dict(data.items())) diff --git a/tests/test_data/test_data_sample.py b/tests/test_data/test_data_sample.py deleted file mode 100644 index e6cece55..00000000 --- a/tests/test_data/test_data_sample.py +++ /dev/null @@ -1,479 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import random -from functools import partial -from unittest import TestCase - -import numpy as np -import pytest -import torch - -from mmengine.data import BaseDataElement, BaseDataSample - - -class TestBaseDataSample(TestCase): - - def setup_data(self): - metainfo = dict( - img_id=random.randint(0, 100), - img_shape=(random.randint(400, 600), random.randint(400, 600))) - gt_instances = BaseDataElement( - data=dict(bboxes=torch.rand((5, 4)), labels=torch.rand((5, )))) - pred_instances = BaseDataElement( - data=dict(bboxes=torch.rand((5, 4)), scores=torch.rand((5, )))) - data = dict(gt_instances=gt_instances, pred_instances=pred_instances) - return metainfo, data - - def is_equal(self, x, y): - assert type(x) == type(y) - if isinstance( - x, (int, float, str, list, tuple, dict, set, BaseDataElement)): - return x == y - elif isinstance(x, (torch.Tensor, np.ndarray)): - return (x == y).all() - - def check_key_value(self, instances, metainfo=None, data=None): - # check the existence of keys in metainfo, data, and instances - if metainfo: - for k, v in metainfo.items(): - assert k in instances - assert k in instances.keys() - assert k in instances.metainfo_keys() - assert k not in instances.data_keys() - assert self.is_equal(instances.get(k), v) - assert self.is_equal(getattr(instances, k), v) - if data: - for k, v in data.items(): - assert k in instances - assert k in instances.keys() - assert k not in instances.metainfo_keys() - assert k in instances.data_keys() - assert self.is_equal(instances.get(k), v) - assert self.is_equal(getattr(instances, k), v) - - def check_data_device(self, instances, device): - # assert instances.device == device - for v in instances.data_values(): - if isinstance(v, torch.Tensor): - assert v.device == torch.device(device) - elif isinstance(v, (BaseDataSample, BaseDataElement)): - self.check_data_device(v, device) - - def check_data_dtype(self, instances, dtype): - for v in instances.data_values(): - if isinstance(v, (torch.Tensor, np.ndarray)): - assert isinstance(v, dtype) - if isinstance(v, (BaseDataSample, BaseDataElement)): - self.check_data_dtype(v, dtype) - - def check_requires_grad(self, instances): - for v in instances.data_values(): - if isinstance(v, torch.Tensor): - assert v.requires_grad is False - if isinstance(v, (BaseDataSample, BaseDataElement)): - self.check_requires_grad(v) - - def test_init(self): - # initialization with no data and metainfo - metainfo, data = self.setup_data() - instances = BaseDataSample() - for k in metainfo: - assert k not in instances - assert instances.get(k, None) is None - for k in data: - assert k not in instances - assert instances.get(k, 'abc') == 'abc' - - # initialization with kwargs - metainfo, data = self.setup_data() - instances = BaseDataSample(metainfo=metainfo, data=data) - self.check_key_value(instances, metainfo, data) - - # initialization with args - metainfo, data = self.setup_data() - instances = BaseDataSample(metainfo, data) - self.check_key_value(instances, metainfo, data) - - # initialization with args - metainfo, data = self.setup_data() - instances = BaseDataSample(metainfo=metainfo) - self.check_key_value(instances, metainfo) - instances = BaseDataSample(data=data) - self.check_key_value(instances, data=data) - - def test_new(self): - metainfo, data = self.setup_data() - instances = BaseDataSample(metainfo=metainfo, data=data) - - # test new() with no arguments - new_instances = instances.new() - assert type(new_instances) == type(instances) - # After deepcopy, the address of new data'element will be same as - # origin, but when change new data' element will not effect the origin - # element and will have new address - _, data = self.setup_data() - new_instances.set_data(data) - assert not self.is_equal(new_instances.gt_instances, - instances.gt_instances) - self.check_key_value(new_instances, metainfo, data) - - # test new() with arguments - metainfo, data = self.setup_data() - new_instances = instances.new(metainfo=metainfo, data=data) - assert type(new_instances) == type(instances) - assert id(new_instances.gt_instances) != id(instances.gt_instances) - _, new_data = self.setup_data() - new_instances.set_data(new_data) - assert id(new_instances.gt_instances) != id(data['gt_instances']) - self.check_key_value(new_instances, metainfo, new_data) - - metainfo, data = self.setup_data() - new_instances = instances.new(metainfo=metainfo) - - def test_set_metainfo(self): - metainfo, _ = self.setup_data() - instances = BaseDataSample() - instances.set_metainfo(metainfo) - self.check_key_value(instances, metainfo=metainfo) - - # test setting existing keys and new keys - new_metainfo, _ = self.setup_data() - new_metainfo.update(other=123) - instances.set_metainfo(new_metainfo) - self.check_key_value(instances, metainfo=new_metainfo) - - # test have the same key in data - # TODO - _, data = self.setup_data() - instances = BaseDataSample(data=data) - _, data = self.setup_data() - with self.assertRaises(AttributeError): - instances.set_metainfo(data) - - with self.assertRaises(AssertionError): - instances.set_metainfo(123) - - def test_set_data(self): - metainfo, data = self.setup_data() - instances = BaseDataSample() - - instances.gt_instances = data['gt_instances'] - instances.pred_instances = data['pred_instances'] - self.check_key_value(instances, data=data) - - # a.xx only set data rather than metainfo - instances.img_shape = metainfo['img_shape'] - instances.img_id = metainfo['img_id'] - self.check_key_value(instances, data=metainfo) - - # test can not set metainfo with `.` - metainfo, data = self.setup_data() - instances = BaseDataSample(metainfo, data) - with self.assertRaises(AttributeError): - instances.img_shape = metainfo['img_shape'] - - # test set '_metainfo_fields' or '_data_fields' - with self.assertRaises(AttributeError): - instances._metainfo_fields = 1 - with self.assertRaises(AttributeError): - instances._data_fields = 1 - - with self.assertRaises(AssertionError): - instances.set_data(123) - - def test_delete_modify(self): - random.seed(10) - metainfo, data = self.setup_data() - instances = BaseDataSample(metainfo, data) - - new_metainfo, new_data = self.setup_data() - # avoid generating same metainfo - while True: - if new_metainfo['img_id'] == metainfo['img_id'] or new_metainfo[ - 'img_shape'] == metainfo['img_shape']: - new_metainfo, new_data = self.setup_data() - else: - break - - instances.gt_instances = new_data['gt_instances'] - instances.pred_instances = new_data['pred_instances'] - - # a.xx only set data rather than metainfo - instances.set_metainfo(new_metainfo) - self.check_key_value(instances, new_metainfo, new_data) - - assert not self.is_equal(instances.gt_instances, data['gt_instances']) - assert not self.is_equal(instances.pred_instances, - data['pred_instances']) - assert not self.is_equal(instances.img_id, metainfo['img_id']) - assert not self.is_equal(instances.img_shape, metainfo['img_shape']) - - del instances.gt_instances - del instances.img_id - assert not self.is_equal( - instances.pop('pred_instances', None), data['pred_instances']) - with self.assertRaises(AttributeError): - del instances.pred_instances - - assert 'gt_instances' not in instances - assert 'pred_instances' not in instances - assert 'img_id' not in instances - assert instances.pop('gt_instances', None) is None - # test pop not exist key without default - with self.assertRaises(KeyError): - instances.pop('gt_instances') - assert instances.pop('pred_instances', 'abcdef') == 'abcdef' - - assert instances.pop('img_id', None) is None - # test pop not exist key without default - with self.assertRaises(KeyError): - instances.pop('img_id') - assert instances.pop('img_shape') == new_metainfo['img_shape'] - - # test del '_metainfo_fields' or '_data_fields' - with self.assertRaises(AttributeError): - del instances._metainfo_fields - with self.assertRaises(AttributeError): - del instances._data_fields - - @pytest.mark.skipif( - not torch.cuda.is_available(), reason='GPU is required!') - def test_cuda(self): - metainfo, data = self.setup_data() - instances = BaseDataSample(metainfo, data) - - cuda_instances = instances.cuda() - self.check_data_device(cuda_instances, 'cuda:0') - - # here we further test to convert from cuda to cpu - cpu_instances = cuda_instances.cpu() - self.check_data_device(cpu_instances, 'cpu') - del cuda_instances - - cuda_instances = instances.to('cuda:0') - self.check_data_device(cuda_instances, 'cuda:0') - - _, data = self.setup_data() - instances = BaseDataSample(metainfo=data) - - cuda_instances = instances.cuda() - self.check_data_device(cuda_instances, 'cuda:0') - - # here we further test to convert from cuda to cpu - cpu_instances = cuda_instances.cpu() - self.check_data_device(cpu_instances, 'cpu') - del cuda_instances - - cuda_instances = instances.to('cuda:0') - self.check_data_device(cuda_instances, 'cuda:0') - - def test_cpu(self): - metainfo, data = self.setup_data() - instances = BaseDataSample(metainfo, data) - self.check_data_device(instances, 'cpu') - - cpu_instances = instances.cpu() - # assert cpu_instances.device == 'cpu' - self.check_data_device(cpu_instances, 'cpu') - - _, data = self.setup_data() - instances = BaseDataSample(metainfo=data) - self.check_data_device(instances, 'cpu') - - cpu_instances = instances.cpu() - # assert cpu_instances.device == 'cpu' - self.check_data_device(cpu_instances, 'cpu') - - def test_numpy_tensor(self): - metainfo, data = self.setup_data() - data.update(bboxes=torch.rand((5, 4))) - instances = BaseDataSample(metainfo, data) - np_instances = instances.numpy() - self.check_data_dtype(np_instances, np.ndarray) - - tensor_instances = np_instances.to_tensor() - self.check_data_dtype(tensor_instances, torch.Tensor) - - _, data = self.setup_data() - data.update(bboxes=torch.rand((5, 4))) - instances = BaseDataSample(metainfo=data) - - np_instances = instances.numpy() - self.check_data_dtype(np_instances, np.ndarray) - - tensor_instances = np_instances.to_tensor() - self.check_data_dtype(tensor_instances, torch.Tensor) - - def test_detach(self): - metainfo, data = self.setup_data() - instances = BaseDataSample(metainfo, data) - instances.detach() - self.check_requires_grad(instances) - - _, data = self.setup_data() - instances = BaseDataSample(metainfo=data) - instances.detach() - self.check_requires_grad(instances) - - def test_repr(self): - metainfo = dict(img_shape=(800, 1196, 3)) - gt_instances = BaseDataElement( - metainfo=metainfo, - data=dict(det_labels=torch.LongTensor([0, 1, 2, 3]))) - - data = dict(gt_instances=gt_instances) - sample = BaseDataSample(metainfo=metainfo, data=data) - address = hex(id(sample)) - address_gt_instances = hex(id(sample.gt_instances)) - assert repr(sample) == (f'<BaseDataSample(' - f'\n META INFORMATION \n' - f'img_shape: (800, 1196, 3) \n' - f'\n DATA FIELDS \n' - f'gt_instances:<BaseDataElement(' - f'\n META INFORMATION \n' - f'img_shape: (800, 1196, 3) \n' - f'\n DATA FIELDS \n' - f'shape of det_labels: torch.Size([4]) \n' - f'\n) at {address_gt_instances}>\n' - f'\n) at {address}>') - - sample = BaseDataSample(data=metainfo, metainfo=data) - address = hex(id(sample)) - address_gt_instances = hex(id(sample.gt_instances)) - assert repr(sample) == (f'<BaseDataSample(' - f'\n META INFORMATION \n' - f'gt_instances:<BaseDataElement(' - f'\n META INFORMATION \n' - f'img_shape: (800, 1196, 3) \n' - f'\n DATA FIELDS \n' - f'shape of det_labels: torch.Size([4]) \n' - f'\n) at {address_gt_instances}>\n' - f'\n DATA FIELDS \n' - f'img_shape: (800, 1196, 3) \n' - f'\n) at {address}>') - metainfo = dict(bboxes=torch.rand((5, 4))) - sample = BaseDataSample(metainfo=metainfo) - address = hex(id(sample)) - assert repr(sample) == (f'<BaseDataSample(' - f'\n META INFORMATION \n' - f'shape of bboxes: torch.Size([5, 4]) \n' - f'\n DATA FIELDS \n' - f'\n) at {address}>') - sample = BaseDataSample(data=metainfo) - address = hex(id(sample)) - assert repr(sample) == (f'<BaseDataSample(' - f'\n META INFORMATION \n' - f'\n DATA FIELDS \n' - f'shape of bboxes: torch.Size([5, 4]) \n' - f'\n) at {address}>') - - def test_set_get_fields(self): - metainfo, data = self.setup_data() - instances = BaseDataSample(metainfo) - for key, value in data.items(): - instances.set_field(name=key, value=value, dtype=BaseDataElement) - self.check_key_value(instances, data=data) - - # test type check - _, data = self.setup_data() - instances = BaseDataSample() - for key, value in data.items(): - with self.assertRaises(AssertionError): - instances.set_field( - name=key, value=value, dtype=BaseDataSample) - - def test_del_field(self): - metainfo, data = self.setup_data() - instances = BaseDataSample(metainfo) - for key, value in data.items(): - instances.set_field(value=value, name=key, dtype=BaseDataElement) - instances.del_field('gt_instances') - instances.del_field('pred_instances') - - # gt_instance has been deleted, instances does not have the gt_instance - with self.assertRaises(AttributeError): - instances.del_field('gt_instances') - assert 'gt_instances' not in instances - assert 'pred_instances' not in instances - - def test_inheritance(self): - - class DetDataSample(BaseDataSample): - proposals = property( - fget=partial(BaseDataSample.get_field, name='_proposals'), - fset=partial( - BaseDataSample.set_field, - name='_proposals', - dtype=BaseDataElement), - fdel=partial(BaseDataSample.del_field, name='_proposals'), - doc='Region proposals of an image') - gt_instances = property( - fget=partial(BaseDataSample.get_field, name='_gt_instances'), - fset=partial( - BaseDataSample.set_field, - name='_gt_instances', - dtype=BaseDataElement), - fdel=partial(BaseDataSample.del_field, name='_gt_instances'), - doc='Ground truth instances of an image') - pred_instances = property( - fget=partial(BaseDataSample.get_field, name='_pred_instances'), - fset=partial( - BaseDataSample.set_field, - name='_pred_instances', - dtype=BaseDataElement), - fdel=partial(BaseDataSample.del_field, name='_pred_instances'), - doc='Predicted instances of an image') - - det_sample = DetDataSample() - - # test set - proposals = BaseDataElement(data=dict(bboxes=torch.rand((5, 4)))) - det_sample.proposals = proposals - assert 'proposals' in det_sample - - # test get - assert det_sample.proposals == proposals - - # test delete - del det_sample.proposals - assert 'proposals' not in det_sample - - # test the data whether meet the requirements - with self.assertRaises(AssertionError): - det_sample.proposals = torch.rand((5, 4)) - - def test_values(self): - # test_metainfo_values - metainfo, data = self.setup_data() - instances = BaseDataSample(metainfo, data) - assert len(instances.metainfo_values()) == len(metainfo.values()) - # test_values - assert len( - instances.values()) == len(metainfo.values()) + len(data.values()) - - # test_data_values - assert len(instances.data_values()) == len(data.values()) - - def test_keys(self): - # test_metainfo_keys - metainfo, data = self.setup_data() - instances = BaseDataSample(metainfo, data) - assert len(instances.metainfo_keys()) == len(metainfo.keys()) - - # test_keys - assert len(instances.keys()) == len(data.keys()) + len(metainfo.keys()) - - # test_data_keys - assert len(instances.data_keys()) == len(data.keys()) - - def test_items(self): - # test_metainfo_items - metainfo, data = self.setup_data() - instances = BaseDataSample(metainfo, data) - assert len(dict(instances.metainfo_items())) == len( - dict(metainfo.items())) - # test_items - assert len(dict(instances.items())) == len(dict( - metainfo.items())) + len(dict(data.items())) - - # test_data_items - assert len(dict(instances.data_items())) == len(dict(data.items())) diff --git a/tests/test_evaluator/test_base_evaluator.py b/tests/test_evaluator/test_base_evaluator.py index 7040abff..bed31b1f 100644 --- a/tests/test_evaluator/test_base_evaluator.py +++ b/tests/test_evaluator/test_base_evaluator.py @@ -5,7 +5,7 @@ from unittest import TestCase import numpy as np -from mmengine.data import BaseDataSample +from mmengine.data import BaseDataElement from mmengine.evaluator import BaseEvaluator, build_evaluator, get_metric_value from mmengine.registry import EVALUATORS @@ -66,8 +66,8 @@ class NonPrefixedEvaluator(BaseEvaluator): """Evaluator with unassigned `default_prefix` to test the warning information.""" - def process(self, data_batch: Sequence[Tuple[Any, BaseDataSample]], - predictions: Sequence[BaseDataSample]) -> None: + def process(self, data_batch: Sequence[Tuple[Any, BaseDataElement]], + predictions: Sequence[BaseDataElement]) -> None: pass def compute_metrics(self, results: list) -> dict: @@ -80,9 +80,9 @@ def generate_test_results(size, batch_size, pred, label): for i in range(num_batch): bs = bs_residual if i == num_batch - 1 else batch_size data_batch = [(np.zeros( - (3, 10, 10)), BaseDataSample(data={'label': label})) + (3, 10, 10)), BaseDataElement(data={'label': label})) for _ in range(bs)] - predictions = [BaseDataSample(data={'pred': pred}) for _ in range(bs)] + predictions = [BaseDataElement(data={'pred': pred}) for _ in range(bs)] yield (data_batch, predictions) diff --git a/tests/test_hook/test_naive_visualization_hook.py b/tests/test_hook/test_naive_visualization_hook.py index 81977d44..4d75fedb 100644 --- a/tests/test_hook/test_naive_visualization_hook.py +++ b/tests/test_hook/test_naive_visualization_hook.py @@ -3,7 +3,7 @@ from unittest.mock import Mock import torch -from mmengine.data import BaseDataSample +from mmengine.data import BaseDataElement from mmengine.hooks import NaiveVisualizationHook @@ -17,7 +17,7 @@ class TestNaiveVisualizationHook: batch_idx = 10 # test with normalize, resize, pad gt_datasamples = [ - BaseDataSample( + BaseDataElement( metainfo=dict( img_norm_cfg=dict( mean=(0, 0, 0), std=(0.5, 0.5, 0.5), to_bgr=True), @@ -27,13 +27,13 @@ class TestNaiveVisualizationHook: ori_width=5, img_path='tmp.jpg')) ] - pred_datasamples = [BaseDataSample()] + pred_datasamples = [BaseDataElement()] data_batch = (inputs, gt_datasamples) naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, pred_datasamples) # test with resize, pad gt_datasamples = [ - BaseDataSample( + BaseDataElement( metainfo=dict( scale=(10, 10), pad_shape=(15, 15, 3), @@ -41,45 +41,45 @@ class TestNaiveVisualizationHook: ori_width=5, img_path='tmp.jpg')), ] - pred_datasamples = [BaseDataSample()] + pred_datasamples = [BaseDataElement()] data_batch = (inputs, gt_datasamples) naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, pred_datasamples) # test with only resize gt_datasamples = [ - BaseDataSample( + BaseDataElement( metainfo=dict( scale=(15, 15), ori_height=5, ori_width=5, img_path='tmp.jpg')), ] - pred_datasamples = [BaseDataSample()] + pred_datasamples = [BaseDataElement()] data_batch = (inputs, gt_datasamples) naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, pred_datasamples) # test with only pad gt_datasamples = [ - BaseDataSample( + BaseDataElement( metainfo=dict( pad_shape=(15, 15, 3), ori_height=5, ori_width=5, img_path='tmp.jpg')), ] - pred_datasamples = [BaseDataSample()] + pred_datasamples = [BaseDataElement()] data_batch = (inputs, gt_datasamples) naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, pred_datasamples) # test no transform gt_datasamples = [ - BaseDataSample( + BaseDataElement( metainfo=dict(ori_height=15, ori_width=15, img_path='tmp.jpg')), ] - pred_datasamples = [BaseDataSample()] + pred_datasamples = [BaseDataElement()] data_batch = (inputs, gt_datasamples) naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, pred_datasamples) diff --git a/tests/test_visualizer/test_visualizer.py b/tests/test_visualizer/test_visualizer.py index a8d3d4a1..5a7da41b 100644 --- a/tests/test_visualizer/test_visualizer.py +++ b/tests/test_visualizer/test_visualizer.py @@ -7,7 +7,7 @@ import numpy as np import pytest import torch -from mmengine.data import BaseDataSample +from mmengine.data import BaseDataElement from mmengine.visualization import Visualizer @@ -366,8 +366,8 @@ class TestVisualizer(TestCase): def draw(self, image: Optional[np.ndarray] = None, - gt_sample: Optional['BaseDataSample'] = None, - pred_sample: Optional['BaseDataSample'] = None, + gt_sample: Optional['BaseDataElement'] = None, + pred_sample: Optional['BaseDataElement'] = None, draw_gt: bool = True, draw_pred: bool = True) -> None: return super().draw(image, gt_sample, pred_sample, draw_gt, -- GitLab