Skip to content
Snippets Groups Projects
Unverified Commit df4e6e32 authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

[Fix] Fix resume `message_hub` and save `metainfo` in message_hub. (#394)

* Fix resume message hub and save metainfo in messagehub

* fix as comment
parent eb251299
No related branches found
No related tags found
No related merge requests found
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Sequence
from mmengine.registry import HOOKS
from ..registry import HOOKS
from ..utils import get_git_hash
from .hook import Hook
DATA_BATCH = Optional[Sequence[dict]]
......@@ -18,12 +19,23 @@ class RuntimeInfoHook(Hook):
priority = 'VERY_HIGH'
def before_run(self, runner) -> None:
    """Record experiment metadata into the runner's message hub.

    Gathers the pretty-printed config, the random seed, the experiment
    name and the mmengine version string (git hash appended) and stores
    them in one call via ``update_info_dict``.
    """
    # Imported locally — presumably to avoid a circular import at module
    # load time; TODO(review) confirm.
    import mmengine
    version_string = mmengine.__version__ + get_git_hash()
    experiment_meta = dict(
        cfg=runner.cfg.pretty_text,
        seed=runner.seed,
        experiment_name=runner.experiment_name,
        mmengine_version=version_string)
    runner.message_hub.update_info_dict(experiment_meta)
def before_train(self, runner) -> None:
    """Update resumed training state.

    Pushes the current training progress counters and the training
    dataset's ``metainfo`` into the message hub so downstream consumers
    (loggers, checkpoints) see the resumed values.
    """
    hub = runner.message_hub
    resumed_state = {
        'epoch': runner.epoch,
        'iter': runner.iter,
        'max_epochs': runner.max_epochs,
        'max_iters': runner.max_iters,
        'dataset_meta': runner.train_dataloader.dataset.metainfo,
    }
    for key, value in resumed_state.items():
        hub.update_info(key, value)
def before_train_epoch(self, runner) -> None:
"""Update current epoch information before every epoch."""
......
......@@ -121,7 +121,8 @@ class MessageHub(ManagerMixin):
keys cannot be modified repeatedly'
Note:
resumed cannot be set repeatedly for the same key.
The ``resumed`` argument needs to be consistent for the same
``key``.
Args:
key (str): Key of ``HistoryBuffer``.
......@@ -149,6 +150,10 @@ class MessageHub(ManagerMixin):
be ``dict(value=xxx) or dict(value=xxx, count=xxx)``. Item in
``log_dict`` has the same resume option.
Note:
The ``resumed`` argument needs to be consistent for the same
``log_dict``.
Args:
log_dict (str): Used for batch updating :attr:`_log_scalars`.
resumed (bool): Whether all ``HistoryBuffer`` referred in
......@@ -187,7 +192,8 @@ class MessageHub(ManagerMixin):
time calling ``update_info``.
Note:
resumed cannot be set repeatedly for the same key.
The ``resumed`` argument needs to be consistent for the same
``key``.
Examples:
>>> message_hub = MessageHub()
......@@ -203,6 +209,31 @@ class MessageHub(ManagerMixin):
self._resumed_keys[key] = resumed
self._runtime_info[key] = value
def update_info_dict(self, info_dict: dict, resumed: bool = True) -> None:
    """Update runtime information with dictionary.

    The key corresponding runtime information will be overwritten each
    time calling ``update_info``.

    Note:
        The ``resumed`` argument needs to be consistent for the same
        ``info_dict``.

    Examples:
        >>> message_hub = MessageHub()
        >>> message_hub.update_info_dict({'iter': 100})

    Args:
        info_dict (dict): Runtime information dictionary.
        resumed (bool): Whether the corresponding ``HistoryBuffer``
            could be resumed.
    """
    # Fixed: the error message referred to `log_dict`, which is the
    # parameter name of `update_scalars`, not of this method.
    assert isinstance(info_dict, dict), ('`info_dict` must be a dict!, '
                                         f'but got {type(info_dict)}')
    for key, value in info_dict.items():
        # NOTE(review): `_set_resumed_keys` appears to also be handled
        # inside `update_info`; kept here to preserve any consistency
        # check it performs — confirm before removing.
        self._set_resumed_keys(key, resumed)
        self.update_info(key, value, resumed=resumed)
def _set_resumed_keys(self, key: str, resumed: bool) -> None:
"""Set corresponding resumed keys.
......@@ -331,7 +362,7 @@ class MessageHub(ManagerMixin):
f'just return its reference. ',
logger='current',
level=logging.WARNING)
saved_scalars[key] = value
saved_info[key] = value
return dict(
log_scalars=saved_scalars,
runtime_info=saved_info,
......@@ -359,9 +390,48 @@ class MessageHub(ManagerMixin):
assert key in state_dict, (
'The loaded `state_dict` of `MessageHub` must contain '
f'key: `{key}`')
self._log_scalars = copy.deepcopy(state_dict['log_scalars'])
self._runtime_info = copy.deepcopy(state_dict['runtime_info'])
self._resumed_keys = copy.deepcopy(state_dict['resumed_keys'])
# The old `MessageHub` could save non-HistoryBuffer `log_scalars`,
# therefore the loaded `log_scalars` needs to be filtered.
for key, value in state_dict['log_scalars'].items():
if not isinstance(value, HistoryBuffer):
print_log(
f'{key} in message_hub is not HistoryBuffer, '
f'just skip resuming it.',
logger='current',
level=logging.WARNING)
continue
self.log_scalars[key] = value
for key, value in state_dict['runtime_info'].items():
try:
self._runtime_info[key] = copy.deepcopy(value)
except: # noqa: E722
print_log(
f'{key} in message_hub cannot be copied, '
f'just return its reference.',
logger='current',
level=logging.WARNING)
self._runtime_info[key] = value
for key, value in state_dict['resumed_keys'].items():
if key not in set(self.log_scalars.keys()) | \
set(self._runtime_info.keys()):
print_log(
f'resumed key: {key} is not defined in message_hub, '
f'just skip resuming this key.',
logger='current',
level=logging.WARNING)
continue
elif not value:
print_log(
f'Although resumed key: {key} is False, {key} '
'will still be loaded this time. This key will '
'not be saved by the next calling of '
'`MessageHub.state_dict()`',
logger='current',
level=logging.WARNING)
self._resumed_keys[key] = value
# Since some checkpoints saved serialized `message_hub` instance,
# `load_state_dict` support loading `message_hub` instance for
# compatibility
......
......@@ -282,7 +282,6 @@ class IterBasedTrainLoop(BaseLoop):
# outputs should be a dict of loss.
outputs = self.runner.model.train_step(
data_batch, optim_wrapper=self.runner.optim_wrapper)
self.runner.message_hub.update_info('train_logs', outputs)
self.runner.call_hook(
'after_train_iter',
......
......@@ -6,7 +6,13 @@ import numpy as np
import pytest
import torch
from mmengine import MessageHub
from mmengine import HistoryBuffer, MessageHub
class NoDeepCopy:
    """Test helper whose instances refuse to be deep-copied.

    Exercises the ``MessageHub`` fallback path that keeps a plain
    reference when ``copy.deepcopy`` raises.
    """

    # Fixed: the original declared a mutable default (`memodict={}`),
    # a Python anti-pattern; `None` is equivalent here since the memo
    # dict is never used.
    def __deepcopy__(self, memodict=None):
        raise NotImplementedError
class TestMessageHub:
......@@ -45,6 +51,15 @@ class TestMessageHub:
message_hub.update_info('key', 1)
assert message_hub.runtime_info['key'] == 1
def test_update_infos(self):
    """``update_info_dict`` stores values and marks keys as resumable."""
    message_hub = MessageHub.get_instance('mmengine')
    payload = {'a': 2, 'b': 3}
    # test runtime value can be overwritten.
    message_hub.update_info_dict(payload)
    for key, expected in payload.items():
        assert message_hub.runtime_info[key] == expected
        # keys updated through the dict API default to resumed=True.
        assert message_hub._resumed_keys[key]
def test_get_scalar(self):
message_hub = MessageHub.get_instance('mmengine')
# Get undefined key will raise error
......@@ -102,14 +117,18 @@ class TestMessageHub:
# update runtime information
message_hub.update_info('iter', 1, resumed=True)
message_hub.update_info('tensor', [1, 2, 3], resumed=False)
no_copy = NoDeepCopy()
message_hub.update_info('no_copy', no_copy, resumed=True)
state_dict = message_hub.state_dict()
assert state_dict['log_scalars']['loss'].data == (np.array([0.1]),
np.array([1]))
assert 'lr' not in state_dict['log_scalars']
assert state_dict['runtime_info']['iter'] == 1
assert 'tensor' not in state_dict
assert 'tensor' not in state_dict['runtime_info']
assert state_dict['runtime_info']['no_copy'] is no_copy
def test_load_state_dict(self):
def test_load_state_dict(self, capsys):
message_hub1 = MessageHub.get_instance('test_load_state_dict1')
# update log_scalars.
message_hub1.update_scalar('loss', 0.1)
......@@ -133,6 +152,23 @@ class TestMessageHub:
np.array([1]))
assert message_hub3.get_info('iter') == 1
# Test resume custom state_dict
state_dict = OrderedDict()
state_dict['log_scalars'] = dict(a=1, b=HistoryBuffer())
state_dict['runtime_info'] = dict(c=1, d=NoDeepCopy(), e=1)
state_dict['resumed_keys'] = dict(
a=True, b=True, c=True, e=False, f=True)
message_hub4 = MessageHub.get_instance('test_load_state_dict4')
message_hub4.load_state_dict(state_dict)
assert 'a' not in message_hub4.log_scalars and 'b' in \
message_hub4.log_scalars
assert 'c' in message_hub4.runtime_info and \
state_dict['runtime_info']['d'] is \
message_hub4.runtime_info['d']
assert message_hub4._resumed_keys == OrderedDict(
b=True, c=True, e=False)
def test_getstate(self):
message_hub = MessageHub.get_instance('name')
# update log_scalars.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment