# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Sequence

from ..registry import HOOKS
from ..utils import get_git_hash
from .hook import Hook

DATA_BATCH = Optional[Sequence[dict]]


@HOOKS.register_module()
class RuntimeInfoHook(Hook):
    """A hook that updates runtime information into the message hub.

    E.g. ``epoch``, ``iter``, ``max_epochs``, and ``max_iters`` for the
    training state. Components that cannot access the runner can get runtime
    information through the message hub.
    """

    priority = 'VERY_HIGH'
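
    # Illustration only, not part of the hook's logic: a component that
    # cannot access the runner can read the values recorded by this hook
    # from the message hub. A minimal sketch, assuming the runner has
    # already created the current ``MessageHub`` instance:
    #
    #     from mmengine.logging import MessageHub
    #
    #     message_hub = MessageHub.get_current_instance()
    #     cur_iter = message_hub.get_info('iter')        # set in before_train_iter
    #     max_iters = message_hub.get_info('max_iters')  # set in before_train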

    def before_run(self, runner) -> None:
        """Update the experiment metainfo into the message hub."""
        import mmengine
        metainfo = dict(
            cfg=runner.cfg.pretty_text,
            seed=runner.seed,
            experiment_name=runner.experiment_name,
            mmengine_version=mmengine.__version__ + get_git_hash())
        runner.message_hub.update_info_dict(metainfo)

    def before_train(self, runner) -> None:
        """Update resumed training state."""
        runner.message_hub.update_info('epoch', runner.epoch)
        runner.message_hub.update_info('iter', runner.iter)
        runner.message_hub.update_info('max_epochs', runner.max_epochs)
        runner.message_hub.update_info('max_iters', runner.max_iters)
        if hasattr(runner.train_dataloader.dataset, 'metainfo'):
            runner.message_hub.update_info(
                'dataset_meta', runner.train_dataloader.dataset.metainfo)
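
    # Sketch (assumes the dataset defines ``metainfo``): once recorded,
    # components such as visualizers can look the dataset meta up without
    # a runner handle, e.g.
    #
    #     meta = MessageHub.get_current_instance().get_info('dataset_meta')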

    def before_train_epoch(self, runner) -> None:
        """Update current epoch information before every epoch."""
        runner.message_hub.update_info('epoch', runner.epoch)

    def before_train_iter(self,
                          runner,
                          batch_idx: int,
                          data_batch: DATA_BATCH = None) -> None:
        """Update current iter and learning rate information before every
        iteration."""
        runner.message_hub.update_info('iter', runner.iter)
        lr_dict = runner.optim_wrapper.get_lr()
        assert isinstance(lr_dict, dict), (
            '`runner.optim_wrapper.get_lr()` should return a dict '
            'of learning rates when training with an OptimWrapper '
            '(single optimizer) or an OptimWrapperDict (multiple '
            f'optimizers), but got {type(lr_dict)}. Please check that '
            'your optimizer constructor returns an `OptimWrapper` or '
            '`OptimWrapperDict` instance.')
        for name, lr in lr_dict.items():
            runner.message_hub.update_scalar(f'train/{name}', lr[0])
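
    # Illustrative shapes (example values, not taken from a real run): with
    # a single optimizer ``get_lr()`` typically returns ``{'lr': [0.01]}``,
    # so the loop above records the scalar ``train/lr`` with value ``0.01``;
    # an ``OptimWrapperDict`` prefixes each key with its optimizer's name.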

    def after_train_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: DATA_BATCH = None,
                         outputs: Optional[dict] = None) -> None:
        """Update ``log_vars`` in model outputs every iteration."""
        if outputs is not None:
            for key, value in outputs.items():
                runner.message_hub.update_scalar(f'train/{key}', value)
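
    # Illustration (hedged sketch): ``outputs`` carries the model's
    # ``log_vars``, e.g. ``{'loss': 0.25}``, which is stored as the scalar
    # ``train/loss``. A consumer could then read a smoothed value back:
    #
    #     loss_buffer = runner.message_hub.get_scalar('train/loss')
    #     recent_mean = loss_buffer.mean(10)  # mean of the last 10 records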

    def after_val_epoch(self,
                        runner,
                        metrics: Optional[Dict[str, float]] = None) -> None:
        """Update evaluation metrics into the message hub after each
        validation epoch.

        Args:
            runner (Runner): The runner of the validation process.
            metrics (Dict[str, float], optional): Evaluation results of all
                metrics on the validation dataset. The keys are the names of
                the metrics, and the values are corresponding results.
        """
        if metrics is not None:
            for key, value in metrics.items():
                runner.message_hub.update_scalar(f'val/{key}', value)

    def after_test_epoch(self,
                         runner,
                         metrics: Optional[Dict[str, float]] = None) -> None:
        """Update evaluation metrics into the message hub after each test
        epoch.

        Args:
            runner (Runner): The runner of the testing process.
            metrics (Dict[str, float], optional): Evaluation results of all
                metrics on the test dataset. The keys are the names of the
                metrics, and the values are corresponding results.
        """
        if metrics is not None:
            for key, value in metrics.items():
                runner.message_hub.update_scalar(f'test/{key}', value)
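
# Usage sketch (assumption: a default MMEngine runner setup). The runner
# registers this hook as one of its default hooks, so user configs normally
# do not need to list it; enabling it explicitly would look like:
#
#     default_hooks = dict(runtime_info=dict(type='RuntimeInfoHook'))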