Unverified Commit 5c3b8e45 authored by Wenwei Zhang, committed by GitHub

Merge Adaptation branch to main branch

parents c3aff4fc 96f3d97f
Showing with 448 additions and 103 deletions
......@@ -63,7 +63,7 @@ jobs:
build_cu102:
machine:
image: ubuntu-1604-cuda-10.1:201909-23 # the actual version of cuda is 10.2
resource_class: gpu.nvidia.small
steps:
- checkout
......
# Metric and Evaluator

During model validation and testing, it is often necessary to make a quantitative evaluation of model accuracy. MMEngine implements [metrics](Todo:metric-doc-link) and [evaluators](Todo:evaluator-doc-link) to perform this function.

**Metrics** compute model accuracy under a specific criterion from the model's input data and predictions. Metrics and datasets are decoupled from each other, which lets users combine any test data with the metrics they need. For example, [COCOMetric](Todo:coco-metric-doc-link) computes AP, AR, and other metrics on the COCO dataset, and it can also be used on other object detection datasets.

**Evaluators** are the layer above metrics and usually contain one or more metrics. An evaluator performs the necessary data format conversions during model evaluation and invokes its metrics to compute model accuracy. Evaluators are usually built by the [runner](TODO:runner-doc-link) or by a test script, which are used for online and offline evaluation respectively.

Users normally do not need to study or modify evaluators in depth, so this document focuses on how metrics work and how to use them.
## Model Accuracy Evaluation

Typically, the process of evaluating model accuracy is shown in the figure below.

**Online evaluation**: The test data is usually split into batches. Through a loop, each batch is fed into the model in turn to obtain its predictions, and then the test data and predictions are passed to the evaluator. The evaluator calls the `process()` method of each metric to process the data and predictions. When the loop ends, the evaluator calls the `evaluate()` method of each metric to compute the model accuracy for the corresponding metric.

**Offline evaluation**: Similar to online evaluation, except that pre-saved model predictions are read and evaluated directly. The evaluator provides an `offline_evaluate` interface for invoking the metrics offline. To avoid running out of memory by processing a large amount of data at once, offline evaluation splits the test data and predictions into chunks, analogous to the batches in online evaluation.
<div align="center">
<img src="https://user-images.githubusercontent.com/15977946/154652635-f4bda588-9f94-462f-b68f-b900690e6215.png"/>
<img src="https://user-images.githubusercontent.com/15977946/163718224-20a4970a-e540-4a3a-8b01-bf0a604c6841.jpg" width="500"/>
</div>
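To make the two modes concrete, here is a minimal sketch of the evaluation loop, assuming an `evaluator` built from the config and a `dataloader` yielding batches of dicts; the exact call signatures are assumptions based on the description above:

```python
# Online evaluation (normally driven by the runner): feed each batch and
# its predictions to the evaluator, then compute all metrics at the end.
for data_batch in dataloader:
    predictions = model(data_batch)             # model forward pass
    evaluator.process(data_batch, predictions)  # calls each metric's process()
metrics = evaluator.evaluate(size=len(dataloader.dataset))

# Offline evaluation: read pre-saved predictions and evaluate in chunks.
metrics = evaluator.offline_evaluate(data, saved_predictions, chunk_size=128)
```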
## Configuring Evaluation Metrics in the Config File

In the config file, the two fields `val_evaluator` and `test_evaluator` specify the evaluation metrics for the model validation and testing stages respectively. For example, a user training a classification model who wants to use both classification accuracy and F1 score during model validation can configure:
```python
val_evaluator = [
    dict(type='Accuracy', top_k=1),  # classification accuracy metric
    dict(type='F1Score')  # F1 score metric
]
```
The `val_evaluator` in the config is used to build an evaluator containing multiple metrics, and each dict in it corresponds to the class and arguments of one metric.

If only a single metric is used, the list in the config can also be omitted and the metric arguments specified directly. For example, to use the classification accuracy metric in the model testing stage:
```python
test_evaluator = dict(type='Accuracy', top_k=1)
```
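Under the hood, each metric dict is resolved through the `METRICS` registry. As a hedged sketch of how such a config might be turned into an evaluator, assuming the `Evaluator` class shown later in this diff accepts a list of metric configs:

```python
# Hypothetical construction sketch; in practice the runner does this for you.
from mmengine.evaluator import Evaluator  # import path assumed

val_evaluator = [
    dict(type='Accuracy', top_k=1),
    dict(type='F1Score'),
]
evaluator = Evaluator(val_evaluator)  # each dict is built via the METRICS registry
```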
## Adding Custom Metrics

The OpenMMLab algorithm libraries already implement the common metrics for their respective fields. For example, MMDetection provides COCO metrics, and MMClassification provides metrics such as Accuracy and F1Score.

Users can also add custom metrics. To implement a custom metric, subclass the metric base class [BaseMetric](Todo:basemetric-doc-link) provided in MMEngine and implement the corresponding abstract methods.

### The Metric Base Class

The metric base class `BaseMetric` is an abstract class with the following two abstract methods:
- `process()`: Processes each batch of test data and model predictions. The processed results should be stored in the `self.results` list, which is used to compute the metric after all the test data has been processed.
- `compute_metrics()`: Computes the metric values and returns them in a dictionary.
`compute_metrics()` is called by the `evaluate()` method; before computing the metric, the latter collects and aggregates the intermediate processing results from the different ranks during distributed testing.

Note that the exact type of the items stored in `self.results` depends on the implementation of the metric subclass. For example, when the test samples or model outputs are too large to keep entirely in memory (as in semantic segmentation, image generation, and similar tasks), `self.results` can hold the per-batch metric values, which are then aggregated in `compute_metrics()`; alternatively, the intermediate results of each batch can be written to temporary files, with `self.results` holding only the file paths, and `compute_metrics()` then reads the data back from the files to compute the metric.
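As an illustration of the temporary-file strategy just described, a subclass might look like the following sketch; the import path and the trivial `num_samples` aggregation are assumptions, not part of this commit:

```python
import os
import pickle
import tempfile

from mmengine.evaluator import BaseMetric  # import path assumed


class LargeOutputMetric(BaseMetric):
    """Sketch: keep only temp-file paths in ``self.results`` to save memory."""

    def process(self, data_batch, predictions):
        # dump this batch's predictions to a temporary file and record
        # only the file path in self.results
        fd, path = tempfile.mkstemp(suffix='.pkl')
        with os.fdopen(fd, 'wb') as f:
            pickle.dump(predictions, f)
        self.results.append(path)

    def compute_metrics(self, results):
        # read the per-batch intermediate results back from disk
        all_preds = []
        for path in results:
            with open(path, 'rb') as f:
                all_preds.extend(pickle.load(f))
            os.remove(path)  # clean up the temporary file
        # a real metric would compute its score from all_preds here
        return dict(num_samples=len(all_preds))
```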
### Implementing a Custom Metric

We take a classification accuracy metric as an example to show how to implement a custom metric.

First, the metric class should inherit from `BaseMetric` and be added to the `METRICS` registry (see the [registry documentation](docs/zh_cn/tutorials/registry.md) for details on registries).

The `process()` method has two input arguments: a batch of test data samples, `data_batch`, and the model predictions, `predictions`. We extract the ground-truth class labels and the predicted classification results from them and store both in `self.results`.

The `compute_metrics()` method has one input argument, `results`, which holds the results of all batches of test data after being processed by `process()`. The class labels and classification predictions are extracted from it to compute the classification accuracy `acc`. Finally, the computed metric values are returned as a dictionary.

In addition, we recommend assigning a value to the class attribute `default_prefix` in the subclass. If `prefix` is not specified in the initialization arguments (i.e., in the config), `default_prefix` is automatically used as the prefix of the metric names. The docstring should also state the class's `default_prefix` value and all of the metric names it returns.

The concrete implementation is as follows:
......@@ -100,7 +75,7 @@ from mmengine.registry import METRICS
import numpy as np
@METRICS.register_module()  # register the Accuracy class to the METRICS registry
class Accuracy(BaseMetric):
""" Accuracy Evaluator
......@@ -110,9 +85,9 @@ class Accuracy(BaseMetric):
- accuracy (float): classification accuracy
"""
default_prefix = 'ACC'  # set default_prefix
def process(self, data_batch: Sequence[dict],
            predictions: Sequence[dict]):
"""Process one batch of data and predictions. The processed
results should be stored in `self.results`, which will be used
......@@ -128,7 +103,7 @@ class Accuracy(BaseMetric):
# extract the classification predictions and ground-truth class labels
result = {
    'pred': predictions['pred_label'],
    'gt': data_batch['data_sample']['gt_label']
}
# store the result of the current batch into self.results
......
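The collapsed part of the diff hides the rest of the class. As a minimal sketch, the matching `compute_metrics()` could look like the following, assuming each entry of `results` is the dict stored by `process()` above:

```python
# Sketch only; the actual implementation is collapsed in the diff above.
def compute_metrics(self, results: List[dict]) -> dict:
    preds = np.concatenate([res['pred'] for res in results])
    gts = np.concatenate([res['gt'] for res in results])
    acc = (preds == gts).sum() / preds.size  # classification accuracy
    return dict(accuracy=acc)
```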
# Copyright (c) OpenMMLab. All rights reserved.
from .base_data_element import BaseDataElement
from .instance_data import InstanceData
from .sampler import DefaultSampler, InfiniteSampler
from .utils import pseudo_collate, worker_init_fn
__all__ = [
'BaseDataElement', 'DefaultSampler', 'InfiniteSampler', 'worker_init_fn',
'pseudo_collate', 'InstanceData'
]
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from typing import List, Union
import numpy as np
import torch
from .base_data_element import BaseDataElement
IndexType = Union[str, slice, int, torch.LongTensor, torch.cuda.LongTensor,
torch.BoolTensor, torch.cuda.BoolTensor, np.long, np.bool]
# Modified from
# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa
class InstanceData(BaseDataElement):
"""Data structure for instance-level annnotations or predictions.
Subclass of :class:`BaseDataElement`. All value in `data_fields`
should have the same length. This design refer to
https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py # noqa E501
Examples:
>>> from mmengine.data import InstanceData
>>> import numpy as np
>>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
>>> instance_data = InstanceData(metainfo=img_meta)
>>> 'img_shape' in instance_data
True
>>> instance_data.det_labels = torch.LongTensor([2, 3])
>>> instance_data["det_scores"] = torch.Tensor([0.8, 0.7])
>>> instance_data.bboxes = torch.rand((2, 4))
>>> len(instance_data)
2
>>> print(instance_data)
<InstanceData(
META INFORMATION
pad_shape: (800, 1216, 3)
img_shape: (800, 1196, 3)
DATA FIELDS
det_labels: tensor([2, 3])
det_scores: tensor([0.8000, 0.7000])
bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188],
[0.8101, 0.3105, 0.5123, 0.6263]])
) at 0x7fb492de6280>
>>> sorted_results = instance_data[instance_data.det_scores.sort().indices]
>>> sorted_results.det_scores
tensor([0.7000, 0.8000])
>>> print(instance_data[instance_data.det_scores > 0.75])
<InstanceData(
META INFORMATION
pad_shape: (800, 1216, 3)
img_shape: (800, 1196, 3)
DATA FIELDS
det_labels: tensor([2])
bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188]])
det_scores: tensor([0.8000])
) at 0x7fb5cf6e2790>
>>> instance_data[instance_data.det_scores > 0.75].det_labels
tensor([2])
>>> instance_data[instance_data.det_scores > 0.75].det_scores
tensor([0.8000])
"""
def __setattr__(self, name: str, value: Union[torch.Tensor, np.ndarray,
list]):
if name in ('_metainfo_fields', '_data_fields'):
if not hasattr(self, name):
super().__setattr__(name, value)
else:
raise AttributeError(
f'{name} has been used as a '
f'private attribute, which is immutable. ')
else:
assert isinstance(value, (torch.Tensor, np.ndarray, list)), \
f'Cannot set {type(value)}, only support' \
f' {(torch.Tensor, np.ndarray, list)}'
if len(self) > 0:
assert len(value) == len(self), f'the length of ' \
f'values {len(value)} is ' \
f'not consistent with' \
f' the length of this ' \
f':obj:`InstanceData` ' \
f'{len(self)} '
super().__setattr__(name, value)
def __getitem__(self, item: IndexType) -> 'InstanceData':
"""
Args:
item (str, obj:`slice`,
obj`torch.LongTensor`, obj:`torch.BoolTensor`):
get the corresponding values according to item.
Returns:
obj:`InstanceData`: Corresponding values.
"""
assert len(self) > 0, 'This is an empty instance'
assert isinstance(
item, (str, slice, int, torch.LongTensor, torch.cuda.LongTensor,
torch.BoolTensor, torch.cuda.BoolTensor, np.bool, np.long))
if isinstance(item, str):
return getattr(self, item)
if type(item) == int:
if item >= len(self) or item < -len(self): # type:ignore
raise IndexError(f'Index {item} out of range!')
else:
# keep the dimension
item = slice(item, None, len(self))
new_data = self.new(data={})
if isinstance(item, torch.Tensor):
assert item.dim() == 1, 'Only support getting values' \
                        ' along the first dimension.'
if isinstance(item, (torch.BoolTensor, torch.cuda.BoolTensor)):
assert len(item) == len(self), f'The shape of the ' \
                               f'input (BoolTensor) ' \
                               f'{len(item)} ' \
                               f'does not match the shape ' \
                               f'of the indexed tensor ' \
                               f'in results_field ' \
                               f'{len(self)} at ' \
                               f'first dimension.'
for k, v in self.items():
if isinstance(v, torch.Tensor):
new_data[k] = v[item]
elif isinstance(v, np.ndarray):
new_data[k] = v[item.cpu().numpy()]
elif isinstance(v, list):
r_list = []
# convert BoolTensor to indexes
if isinstance(item,
(torch.BoolTensor, torch.cuda.BoolTensor)):
indexes = torch.nonzero(item).view(-1)
else:
indexes = item
for index in indexes:
r_list.append(v[index])
new_data[k] = r_list
else:
# item is a slice
for k, v in self.items():
new_data[k] = v[item]
return new_data # type:ignore
@staticmethod
def cat(instances_list: List['InstanceData']) -> 'InstanceData':
"""Concat the instances of all :obj:`InstanceData` in the list.
Note: To ensure that cat returns as expected, make sure that
all elements in the list have exactly the same keys.
Args:
instances_list (list[:obj:`InstanceData`]): A list
of :obj:`InstanceData`.
Returns:
obj:`InstanceData`
"""
assert all(
isinstance(results, InstanceData) for results in instances_list)
assert len(instances_list) > 0
if len(instances_list) == 1:
return instances_list[0]
# metainfo and data_fields must be exactly the
# same for each element to avoid exceptions.
field_keys_list = [
instances.all_keys() for instances in instances_list
]
assert len(set([len(field_keys) for field_keys in field_keys_list])) \
== 1 and len(set(itertools.chain(*field_keys_list))) \
== len(field_keys_list[0]), 'There are different keys in ' \
'`instances_list`, which may ' \
'cause the cat operation ' \
'to fail. Please make sure all ' \
'elements in `instances_list` ' \
'have the exact same keys.'
new_data = instances_list[0].new(data={})
for k in instances_list[0].keys():
values = [results[k] for results in instances_list]
v0 = values[0]
if isinstance(v0, torch.Tensor):
values = torch.cat(values, dim=0)
elif isinstance(v0, np.ndarray):
values = np.concatenate(values, axis=0)
elif isinstance(v0, list):
values = list(itertools.chain(*values))
else:
raise ValueError(
f'Cannot concat the {k} which is a {type(v0)}')
new_data[k] = values
return new_data # type:ignore
def __len__(self) -> int:
if len(self._data_fields) > 0:
return len(self.values()[0])
else:
return 0
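A short usage sketch for the `cat()` interface above; note that all elements must share exactly the same keys (the import path follows the class docstring):

```python
# Usage sketch for InstanceData.cat, concatenating two instance sets.
import torch
from mmengine.data import InstanceData

a = InstanceData(metainfo=dict(img_shape=(800, 1196, 3)))
a.bboxes = torch.rand((2, 4))
a.det_labels = torch.LongTensor([0, 1])

b = InstanceData(metainfo=dict(img_shape=(800, 1196, 3)))
b.bboxes = torch.rand((3, 4))
b.det_labels = torch.LongTensor([2, 3, 4])

merged = InstanceData.cat([a, b])
assert len(merged) == 5  # 2 + 3 instances, concatenated along dim 0
```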
......@@ -2,18 +2,13 @@
import itertools
import math
from typing import Iterator, Optional, Sized
import torch
from torch.utils.data import Sampler
from mmengine.dist import get_dist_info, sync_random_seed
from mmengine.registry import DATA_SAMPLERS
@DATA_SAMPLERS.register_module()
class DefaultSampler(Sampler):
......
# Copyright (c) OpenMMLab. All rights reserved.
import random
from typing import Sequence
import numpy as np
import torch
DATA_BATCH = Sequence[dict]
def worker_init_fn(worker_id: int, num_workers: int, rank: int,
......@@ -36,10 +34,10 @@ def pseudo_collate(data_batch: DATA_BATCH) -> DATA_BATCH:
nothing just returns ``data_batch``.
Args:
data_batch (Sequence[dict]): Batch of data from
dataloader.
Returns:
Sequence[dict]: Return input ``data_batch``.
"""
return data_batch
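Because `pseudo_collate` leaves the batch untouched, it can be passed as the `collate_fn` of a PyTorch `DataLoader` so that each batch stays a plain sequence of dicts; a small sketch with a hypothetical `my_dataset`:

```python
from torch.utils.data import DataLoader

# my_dataset is a hypothetical dataset whose __getitem__ returns a dict
loader = DataLoader(my_dataset, batch_size=4, collate_fn=pseudo_collate)
for data_batch in loader:
    # data_batch is a list of 4 sample dicts, not a collated tensor batch
    assert isinstance(data_batch, list) and isinstance(data_batch[0], dict)
```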
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Iterator, List, Optional, Sequence, Union
from mmengine.data import BaseDataElement
from ..registry.root import METRICS
......@@ -37,23 +37,25 @@ class Evaluator:
for metric in self.metrics:
metric.dataset_meta = dataset_meta
def process(self, data_batch: Sequence[dict],
predictions: Sequence[BaseDataElement]):
"""Convert ``BaseDataSample`` to dict and invoke process method of each
metric.
Args:
data_batch (Sequence[dict]): A batch of data from the dataloader.
predictions (Sequence[BaseDataElement]): A batch of outputs from
the model.
"""
_data_batch = []
for data in data_batch:
if isinstance(data['data_sample'], BaseDataElement):
_data_batch.append(
dict(
inputs=data['inputs'],
data_sample=data['data_sample'].to_dict()))
else:
_data_batch.append(data)
_predictions = []
for pred in predictions:
if isinstance(pred, BaseDataElement):
......
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from abc import ABCMeta, abstractmethod
from typing import Any, List, Optional, Sequence, Union
from mmengine.dist import (broadcast_object_list, collect_results,
is_main_process)
......@@ -50,15 +50,14 @@ class BaseMetric(metaclass=ABCMeta):
self._dataset_meta = dataset_meta
@abstractmethod
def process(self, data_batch: Sequence[dict],
predictions: Sequence[dict]) -> None:
"""Process one batch of data samples and predictions. The processed
results should be stored in ``self.results``, which will be used to
compute the metrics when all batches have been processed.
Args:
data_batch (Sequence[dict]): A batch of data from the dataloader.
predictions (Sequence[dict]): A batch of outputs from
the model.
"""
......
......@@ -2,15 +2,14 @@
import os.path as osp
import warnings
from pathlib import Path
from typing import Optional, Sequence, Union
from mmengine.data import BaseDataElement
from mmengine.dist import master_only
from mmengine.fileio import FileClient
from mmengine.registry import HOOKS
from .hook import Hook
DATA_BATCH = Optional[Sequence[dict]]
@HOOKS.register_module()
......@@ -185,8 +184,8 @@ class CheckpointHook(Hook):
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
data_batch (Sequence[dict], optional): Data from dataloader.
Defaults to None.
outputs (dict, optional): Outputs from model.
Defaults to None.
"""
......
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence, Union
import torch
......@@ -7,7 +7,7 @@ from mmengine.data import BaseDataElement
from mmengine.registry import HOOKS
from .hook import Hook
DATA_BATCH = Optional[Sequence[dict]]
@HOOKS.register_module()
......@@ -46,8 +46,8 @@ class EmptyCacheHook(Hook):
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the loop.
data_batch (Sequence[dict], optional): Data from dataloader.
Defaults to None.
outputs (dict or sequence, optional): Outputs from model.
Defaults to None.
mode (str): Current mode of runner. Defaults to 'train'.
......
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence, Union
from mmengine.data import BaseDataElement
DATA_BATCH = Optional[Sequence[dict]]
class Hook:
......@@ -174,8 +174,8 @@ class Hook:
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
data_batch (Sequence[dict], optional): Data from dataloader.
Defaults to None.
"""
self._before_iter(
runner, batch_idx=batch_idx, data_batch=data_batch, mode='train')
......@@ -190,8 +190,8 @@ class Hook:
Args:
runner (Runner): The runner of the validation process.
batch_idx (int): The index of the current batch in the val loop.
data_batch (Sequence[dict], optional): Data from dataloader.
Defaults to None.
"""
self._before_iter(
runner, batch_idx=batch_idx, data_batch=data_batch, mode='val')
......@@ -206,8 +206,8 @@ class Hook:
Args:
runner (Runner): The runner of the testing process.
batch_idx (int): The index of the current batch in the test loop.
data_batch (Sequence[dict], optional): Data from dataloader.
Defaults to None.
"""
self._before_iter(
runner, batch_idx=batch_idx, data_batch=data_batch, mode='test')
......@@ -223,8 +223,8 @@ class Hook:
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
data_batch (Sequence[dict], optional): Data from dataloader.
Defaults to None.
outputs (dict, optional): Outputs from model.
Defaults to None.
"""
......@@ -247,8 +247,8 @@ class Hook:
Args:
runner (Runner): The runner of the validation process.
batch_idx (int): The index of the current batch in the val loop.
data_batch (Sequence[dict], optional): Data from dataloader.
Defaults to None.
outputs (dict or sequence, optional): Outputs from
model. Defaults to None.
"""
......@@ -271,8 +271,8 @@ class Hook:
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the test loop.
data_batch (Sequence[dict], optional): Data from dataloader.
Defaults to None.
outputs (dict, optional): Outputs from model.
Defaults to None.
"""
......@@ -317,8 +317,8 @@ class Hook:
runner (Runner): The runner of the training, validation or testing
process.
batch_idx (int): The index of the current batch in the loop.
data_batch (Sequence[dict], optional): Data from dataloader.
Defaults to None.
mode (str): Current mode of runner. Defaults to 'train'.
"""
pass
......@@ -337,8 +337,8 @@ class Hook:
runner (Runner): The runner of the training, validation or testing
process.
batch_idx (int): The index of the current batch in the loop.
data_batch (Sequence[dict], optional): Data from dataloader.
Defaults to None.
outputs (Sequence[BaseDataElement], optional): Outputs from model.
Defaults to None.
mode (str): Current mode of runner. Defaults to 'train'.
......
# Copyright (c) OpenMMLab. All rights reserved.
import time
from typing import Optional, Sequence, Union
from mmengine.data import BaseDataElement
from mmengine.registry import HOOKS
from .hook import Hook
DATA_BATCH = Optional[Sequence[dict]]
@HOOKS.register_module()
......@@ -53,8 +53,8 @@ class IterTimerHook(Hook):
runner (Runner): The runner of the training, validation and
testing process.
batch_idx (int): The index of the current batch in the loop.
data_batch (Sequence[dict], optional): Data from dataloader.
Defaults to None.
mode (str): Current mode of runner. Defaults to 'train'.
"""
# Update data loading time in `runner.message_hub`.
......@@ -75,8 +75,8 @@ class IterTimerHook(Hook):
runner (Runner): The runner of the training validation and
testing process.
batch_idx (int): The index of the current batch in the loop.
data_batch (Sequence[dict], optional): Data from dataloader.
Defaults to None.
outputs (dict or sequence, optional): Outputs from model. Defaults
to None.
mode (str): Current mode of runner. Defaults to 'train'.
......
......@@ -2,7 +2,7 @@
import os
import os.path as osp
from pathlib import Path
from typing import Optional, Sequence, Union
from mmengine.data import BaseDataElement
from mmengine.fileio import FileClient
......@@ -10,7 +10,7 @@ from mmengine.hooks import Hook
from mmengine.registry import HOOKS
from mmengine.utils import is_tuple_of, scandir
DATA_BATCH = Optional[Sequence[dict]]
@HOOKS.register_module()
......@@ -110,9 +110,6 @@ class LoggerHook(Hook):
f'{runner.timestamp}.log.json')
self.yaml_log_path = osp.join(runner.work_dir,
f'{runner.timestamp}.log.json')
def after_train_iter(self,
runner,
......@@ -124,8 +121,8 @@ class LoggerHook(Hook):
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
data_batch (Sequence[dict], optional): Data from dataloader.
Defaults to None.
outputs (dict, optional): Outputs from model.
Defaults to None.
"""
......@@ -150,7 +147,7 @@ class LoggerHook(Hook):
return
runner.logger.info(log_str)
runner.visualizer.add_scalars(tag, step=runner.iter + 1)
def after_val_iter(
self,
......@@ -202,7 +199,7 @@ class LoggerHook(Hook):
runner, len(runner.val_dataloader), 'val')
runner.logger.info(log_str)
runner.visualizer.add_scalars(tag, step=runner.iter + 1)
def after_test_epoch(self, runner) -> None:
"""Record testing logs after test epoch.
......
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Optional, Sequence, Tuple
import cv2
import numpy as np
......@@ -11,6 +11,8 @@ from mmengine.registry import HOOKS
from mmengine.utils.misc import tensor2imgs
# TODO: Due to interface changes, the current class
# functions incorrectly
@HOOKS.register_module()
class NaiveVisualizationHook(Hook):
"""Show or Write the predicted results during the process of testing.
......@@ -41,26 +43,25 @@ class NaiveVisualizationHook(Hook):
self,
runner,
batch_idx: int,
data_batch: Optional[Sequence[dict]] = None,
outputs: Optional[Sequence[BaseDataElement]] = None) -> None:
"""Show or Write the predicted results.
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the test loop.
data_batch (Sequence[dict], optional): Data
from dataloader. Defaults to None.
outputs (Sequence[BaseDataElement], optional): Outputs from model.
Defaults to None.
"""
if self.every_n_iters(runner, self._interval):
for data, output in zip(data_batch, outputs): # type: ignore
input = data['inputs']
data_sample = data['data_sample']
input = tensor2imgs(input,
**data_sample.get('img_norm_cfg',
dict()))[0]
# TODO We will implement a function to revert the augmentation
# in the future.
ori_shape = (data_sample.ori_width, data_sample.ori_height)
......@@ -69,5 +70,6 @@ class NaiveVisualizationHook(Hook):
data_sample.get('scale', ori_shape))
origin_image = cv2.resize(input, ori_shape)
name = osp.basename(data_sample.img_path)
runner.visualizer.add_datasample(name, origin_image,
data_sample, output,
self.draw_gt, self.draw_pred)
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import List, Optional, Sequence
import torch
from torch.nn.parameter import Parameter
from torch.nn.utils import clip_grad
from mmengine.data import BaseDataElement
from mmengine.registry import HOOKS
from .hook import Hook
DATA_BATCH = Optional[Sequence[dict]]
@HOOKS.register_module()
......@@ -77,10 +76,9 @@ class OptimizerHook(Hook):
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
data_batch (Sequence[dict], optional): Data from dataloader.
In order to keep this interface consistent with other hooks,
we keep ``data_batch`` here. Defaults to None.
outputs (dict, optional): Outputs from model.
In order to keep this interface consistent with other hooks,
we keep ``outputs`` here. Defaults to None.
......
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence
from mmengine.data import BaseDataElement
from mmengine.registry import HOOKS
from .hook import Hook
DATA_BATCH = Optional[Sequence[dict]]
@HOOKS.register_module()
......@@ -25,10 +24,9 @@ class ParamSchedulerHook(Hook):
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
data_batch (Sequence[dict], optional): Data from dataloader.
In order to keep this interface consistent with other hooks,
we keep ``data_batch`` here. Defaults to None.
outputs (dict, optional): Outputs from model.
In order to keep this interface consistent with other hooks, we
keep ``outputs`` here. Defaults to None.
......
......@@ -20,9 +20,17 @@ class DistSamplerSeedHook(Hook):
Args:
runner (Runner): The runner of the training process.
"""
if hasattr(runner.train_loop.dataloader, 'sampler') and hasattr(
runner.train_loop.dataloader.sampler, 'set_epoch'):
# In case the `_SingleProcessDataLoaderIter` has no sampler,
# or data loader uses `SequentialSampler` in Pytorch.
runner.train_loop.dataloader.sampler.set_epoch(runner.epoch)
elif hasattr(runner.train_loop.dataloader,
'batch_sampler') and hasattr(
runner.train_loop.dataloader.batch_sampler.sampler,
'set_epoch'):
# In case the `_SingleProcessDataLoaderIter` has no batch sampler.
# The batch sampler in PyTorch wraps the sampler as one of its attributes.
runner.train_loop.dataloader.batch_sampler.sampler.set_epoch(
runner.epoch)
......@@ -168,7 +168,7 @@ class LogProcessor:
if mode in ('train', 'val'):
log_items = []
for name, val in log_tag.items():
if mode == 'val' and not name.startswith('val/loss'):
continue
if isinstance(val, float):
val = f'{val:.4f}'
......
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from collections import OrderedDict
from typing import Any, Optional, Union
......@@ -229,7 +228,8 @@ class MessageHub(ManagerMixin):
Returns:
OrderedDict: A copy of all runtime information.
"""
# return copy.deepcopy(self._runtime_info)
return self._runtime_info
def get_scalar(self, key: str) -> HistoryBuffer:
"""Get ``HistoryBuffer`` instance by key.
......@@ -263,7 +263,10 @@ class MessageHub(ManagerMixin):
if key not in self.runtime_info:
raise KeyError(f'{key} is not found in Messagehub.log_buffers: '
f'instance name is: {MessageHub.instance_name}')
# TODO: There are restrictions on objects that can be saved
# return copy.deepcopy(self._runtime_info[key])
return self._runtime_info[key]
def _get_valid_value(self, key: str,
value: Union[torch.Tensor, np.ndarray, int, float]) \
......
......@@ -9,6 +9,9 @@ from torch.nn.parallel.distributed import (DistributedDataParallel,
from mmengine.registry import MODEL_WRAPPERS
from mmengine.utils import TORCH_VERSION, digit_version
MODEL_WRAPPERS.register_module(module=DataParallel)
MODEL_WRAPPERS.register_module(module=DistributedDataParallel)
@MODEL_WRAPPERS.register_module()
class MMDataParallel(DataParallel):
......