From df4e6e32940aca36b4c14ac534ee0911b2d18a61 Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Tue, 2 Aug 2022 19:06:35 +0800 Subject: [PATCH] [Fix] Fix resume `message_hub` and save `metainfo` in message_hub. (#394) * Fix resume message hub and save metainfo in messagehub * fix as comment --- mmengine/hooks/runtime_info_hook.py | 14 ++++- mmengine/logging/message_hub.py | 82 ++++++++++++++++++++++++-- mmengine/runner/loops.py | 1 - tests/test_logging/test_message_hub.py | 42 ++++++++++++- 4 files changed, 128 insertions(+), 11 deletions(-) diff --git a/mmengine/hooks/runtime_info_hook.py b/mmengine/hooks/runtime_info_hook.py index 6098d540..bfce884b 100644 --- a/mmengine/hooks/runtime_info_hook.py +++ b/mmengine/hooks/runtime_info_hook.py @@ -1,7 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import Dict, Optional, Sequence -from mmengine.registry import HOOKS +from ..registry import HOOKS +from ..utils import get_git_hash from .hook import Hook DATA_BATCH = Optional[Sequence[dict]] @@ -18,12 +19,23 @@ class RuntimeInfoHook(Hook): priority = 'VERY_HIGH' + def before_run(self, runner) -> None: + import mmengine + metainfo = dict( + cfg=runner.cfg.pretty_text, + seed=runner.seed, + experiment_name=runner.experiment_name, + mmengine_version=mmengine.__version__ + get_git_hash()) + runner.message_hub.update_info_dict(metainfo) + def before_train(self, runner) -> None: """Update resumed training state.""" runner.message_hub.update_info('epoch', runner.epoch) runner.message_hub.update_info('iter', runner.iter) runner.message_hub.update_info('max_epochs', runner.max_epochs) runner.message_hub.update_info('max_iters', runner.max_iters) + runner.message_hub.update_info( + 'dataset_meta', runner.train_dataloader.dataset.metainfo) def before_train_epoch(self, runner) -> None: """Update current epoch information before every epoch.""" diff --git a/mmengine/logging/message_hub.py b/mmengine/logging/message_hub.py 
index a3083b10..f36db4a3 100644 --- a/mmengine/logging/message_hub.py +++ b/mmengine/logging/message_hub.py @@ -121,7 +121,8 @@ class MessageHub(ManagerMixin): keys cannot be modified repeatedly' Note: - resumed cannot be set repeatedly for the same key. + The ``resumed`` argument needs to be consistent for the same + ``key``. Args: key (str): Key of ``HistoryBuffer``. @@ -149,6 +150,10 @@ class MessageHub(ManagerMixin): be ``dict(value=xxx) or dict(value=xxx, count=xxx)``. Item in ``log_dict`` has the same resume option. + Note: + The ``resumed`` argument needs to be consistent for the same + ``log_dict``. + Args: log_dict (str): Used for batch updating :attr:`_log_scalars`. resumed (bool): Whether all ``HistoryBuffer`` referred in @@ -187,7 +192,8 @@ class MessageHub(ManagerMixin): time calling ``update_info``. Note: - resumed cannot be set repeatedly for the same key. + The ``resumed`` argument needs to be consistent for the same + ``key``. Examples: >>> message_hub = MessageHub() @@ -203,6 +209,31 @@ class MessageHub(ManagerMixin): self._resumed_keys[key] = resumed self._runtime_info[key] = value + def update_info_dict(self, info_dict: dict, resumed: bool = True) -> None: + """Update runtime information with dictionary. + + The runtime information for each key in ``info_dict`` is overwritten + every time ``update_info_dict`` is called. + + Note: + The ``resumed`` argument needs to be consistent for the same + ``info_dict``. + + Examples: + >>> message_hub = MessageHub() + >>> message_hub.update_info_dict({'iter': 100}) + + Args: + info_dict (dict): Runtime information dictionary. + resumed (bool): Whether the runtime information in + ``info_dict`` could be resumed.
+ """ + assert isinstance(info_dict, dict), ('`log_dict` must be a dict!, ' + f'but got {type(info_dict)}') + for key, value in info_dict.items(): + self._set_resumed_keys(key, resumed) + self.update_info(key, value, resumed=resumed) + def _set_resumed_keys(self, key: str, resumed: bool) -> None: """Set corresponding resumed keys. @@ -331,7 +362,7 @@ class MessageHub(ManagerMixin): f'just return its reference. ', logger='current', level=logging.WARNING) - saved_scalars[key] = value + saved_info[key] = value return dict( log_scalars=saved_scalars, runtime_info=saved_info, @@ -359,9 +390,48 @@ class MessageHub(ManagerMixin): assert key in state_dict, ( 'The loaded `state_dict` of `MessageHub` must contain ' f'key: `{key}`') - self._log_scalars = copy.deepcopy(state_dict['log_scalars']) - self._runtime_info = copy.deepcopy(state_dict['runtime_info']) - self._resumed_keys = copy.deepcopy(state_dict['resumed_keys']) + # The old `MessageHub` could save non-HistoryBuffer `log_scalars`, + # therefore the loaded `log_scalars` needs to be filtered. 
+ for key, value in state_dict['log_scalars'].items(): + if not isinstance(value, HistoryBuffer): + print_log( + f'{key} in message_hub is not HistoryBuffer, ' + f'just skip resuming it.', + logger='current', + level=logging.WARNING) + continue + self.log_scalars[key] = value + + for key, value in state_dict['runtime_info'].items(): + try: + self._runtime_info[key] = copy.deepcopy(value) + except: # noqa: E722 + print_log( + f'{key} in message_hub cannot be copied, ' + f'just return its reference.', + logger='current', + level=logging.WARNING) + self._runtime_info[key] = value + + for key, value in state_dict['resumed_keys'].items(): + if key not in set(self.log_scalars.keys()) | \ + set(self._runtime_info.keys()): + print_log( + f'resumed key: {key} is not defined in message_hub, ' + f'just skip resuming this key.', + logger='current', + level=logging.WARNING) + continue + elif not value: + print_log( + f'Although resumed key: {key} is False, {key} ' + 'will still be loaded this time. This key will ' + 'not be saved by the next calling of ' + '`MessageHub.state_dict()`', + logger='current', + level=logging.WARNING) + self._resumed_keys[key] = value + # Since some checkpoints saved serialized `message_hub` instance, # `load_state_dict` support loading `message_hub` instance for # compatibility diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index e155263b..75ba3fc5 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -282,7 +282,6 @@ class IterBasedTrainLoop(BaseLoop): # outputs should be a dict of loss. 
outputs = self.runner.model.train_step( data_batch, optim_wrapper=self.runner.optim_wrapper) - self.runner.message_hub.update_info('train_logs', outputs) self.runner.call_hook( 'after_train_iter', diff --git a/tests/test_logging/test_message_hub.py b/tests/test_logging/test_message_hub.py index e0e1180d..2e6d5ddc 100644 --- a/tests/test_logging/test_message_hub.py +++ b/tests/test_logging/test_message_hub.py @@ -6,7 +6,13 @@ import numpy as np import pytest import torch -from mmengine import MessageHub +from mmengine import HistoryBuffer, MessageHub + + +class NoDeepCopy: + + def __deepcopy__(self, memodict={}): + raise NotImplementedError class TestMessageHub: @@ -45,6 +51,15 @@ class TestMessageHub: message_hub.update_info('key', 1) assert message_hub.runtime_info['key'] == 1 + def test_update_infos(self): + message_hub = MessageHub.get_instance('mmengine') + # test runtime value can be overwritten. + message_hub.update_info_dict({'a': 2, 'b': 3}) + assert message_hub.runtime_info['a'] == 2 + assert message_hub.runtime_info['b'] == 3 + assert message_hub._resumed_keys['a'] + assert message_hub._resumed_keys['b'] + def test_get_scalar(self): message_hub = MessageHub.get_instance('mmengine') # Get undefined key will raise error @@ -102,14 +117,18 @@ class TestMessageHub: # update runtime information message_hub.update_info('iter', 1, resumed=True) message_hub.update_info('tensor', [1, 2, 3], resumed=False) + no_copy = NoDeepCopy() + message_hub.update_info('no_copy', no_copy, resumed=True) state_dict = message_hub.state_dict() + assert state_dict['log_scalars']['loss'].data == (np.array([0.1]), np.array([1])) assert 'lr' not in state_dict['log_scalars'] assert state_dict['runtime_info']['iter'] == 1 - assert 'tensor' not in state_dict + assert 'tensor' not in state_dict['runtime_info'] + assert state_dict['runtime_info']['no_copy'] is no_copy - def test_load_state_dict(self): + def test_load_state_dict(self, capsys): message_hub1 = 
MessageHub.get_instance('test_load_state_dict1') # update log_scalars. message_hub1.update_scalar('loss', 0.1) @@ -133,6 +152,23 @@ class TestMessageHub: np.array([1])) assert message_hub3.get_info('iter') == 1 + # Test resume custom state_dict + state_dict = OrderedDict() + state_dict['log_scalars'] = dict(a=1, b=HistoryBuffer()) + state_dict['runtime_info'] = dict(c=1, d=NoDeepCopy(), e=1) + state_dict['resumed_keys'] = dict( + a=True, b=True, c=True, e=False, f=True) + + message_hub4 = MessageHub.get_instance('test_load_state_dict4') + message_hub4.load_state_dict(state_dict) + assert 'a' not in message_hub4.log_scalars and 'b' in \ + message_hub4.log_scalars + assert 'c' in message_hub4.runtime_info and \ + state_dict['runtime_info']['d'] is \ + message_hub4.runtime_info['d'] + assert message_hub4._resumed_keys == OrderedDict( + b=True, c=True, e=False) + def test_getstate(self): message_hub = MessageHub.get_instance('name') # update log_scalars. -- GitLab