From df4e6e32940aca36b4c14ac534ee0911b2d18a61 Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Tue, 2 Aug 2022 19:06:35 +0800
Subject: [PATCH] [Fix] Fix resume `message_hub` and save `metainfo` in
 message_hub. (#394)

* Fix resuming `message_hub` and save `metainfo` in `message_hub`

* Fix as per review comments
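
A rough usage sketch of the new batch API `update_info_dict` added by
this patch (the instance name 'demo' is illustrative only):

    from mmengine.logging import MessageHub

    message_hub = MessageHub.get_instance('demo')
    # `resumed` applies to every key in the dict.
    message_hub.update_info_dict(dict(epoch=1, iter=100), resumed=True)
    assert message_hub.get_info('iter') == 100

Keys stored with `resumed=False` are excluded from
`MessageHub.state_dict()`, so they are not carried over when training
resumes from a checkpoint.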
---
 mmengine/hooks/runtime_info_hook.py    | 14 ++++-
 mmengine/logging/message_hub.py        | 82 ++++++++++++++++++++++++--
 mmengine/runner/loops.py               |  1 -
 tests/test_logging/test_message_hub.py | 42 ++++++++++++-
 4 files changed, 128 insertions(+), 11 deletions(-)
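
For reviewers, a minimal sketch of the resume filtering implemented in
`load_state_dict` below (instance name and values are made up for this
example):

    from mmengine.logging import HistoryBuffer, MessageHub

    hub = MessageHub.get_instance('resume_demo')
    state = dict(
        log_scalars=dict(loss=HistoryBuffer(), stale='not-a-buffer'),
        runtime_info=dict(iter=7),
        resumed_keys=dict(loss=True, iter=True))
    hub.load_state_dict(state)
    # Non-HistoryBuffer scalars are skipped with a warning rather than
    # raising, which keeps old checkpoints loadable.
    assert 'stale' not in hub.log_scalars
    assert hub.get_info('iter') == 7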

diff --git a/mmengine/hooks/runtime_info_hook.py b/mmengine/hooks/runtime_info_hook.py
index 6098d540..bfce884b 100644
--- a/mmengine/hooks/runtime_info_hook.py
+++ b/mmengine/hooks/runtime_info_hook.py
@@ -1,7 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from typing import Dict, Optional, Sequence
 
-from mmengine.registry import HOOKS
+from ..registry import HOOKS
+from ..utils import get_git_hash
 from .hook import Hook
 
 DATA_BATCH = Optional[Sequence[dict]]
@@ -18,12 +19,23 @@ class RuntimeInfoHook(Hook):
 
     priority = 'VERY_HIGH'
 
+    def before_run(self, runner) -> None:
+        import mmengine  # local import; only ``__version__`` is needed
+        metainfo = dict(
+            cfg=runner.cfg.pretty_text,
+            seed=runner.seed,
+            experiment_name=runner.experiment_name,
+            mmengine_version=mmengine.__version__ + get_git_hash())
+        runner.message_hub.update_info_dict(metainfo)
+
     def before_train(self, runner) -> None:
         """Update resumed training state."""
         runner.message_hub.update_info('epoch', runner.epoch)
         runner.message_hub.update_info('iter', runner.iter)
         runner.message_hub.update_info('max_epochs', runner.max_epochs)
         runner.message_hub.update_info('max_iters', runner.max_iters)
+        runner.message_hub.update_info(
+            'dataset_meta', runner.train_dataloader.dataset.metainfo)
 
     def before_train_epoch(self, runner) -> None:
         """Update current epoch information before every epoch."""
diff --git a/mmengine/logging/message_hub.py b/mmengine/logging/message_hub.py
index a3083b10..f36db4a3 100644
--- a/mmengine/logging/message_hub.py
+++ b/mmengine/logging/message_hub.py
@@ -121,7 +121,8 @@ class MessageHub(ManagerMixin):
             keys cannot be modified repeatedly'
 
         Note:
-            resumed cannot be set repeatedly for the same key.
+            The ``resumed`` argument needs to be consistent for the same
+            ``key``.
 
         Args:
             key (str): Key of ``HistoryBuffer``.
@@ -149,6 +150,10 @@ class MessageHub(ManagerMixin):
         be ``dict(value=xxx) or dict(value=xxx, count=xxx)``. Item in
         ``log_dict`` has the same resume option.
 
+        Note:
+            The ``resumed`` argument needs to be consistent for the same
+            ``log_dict``.
+
         Args:
             log_dict (str): Used for batch updating :attr:`_log_scalars`.
             resumed (bool): Whether all ``HistoryBuffer`` referred in
@@ -187,7 +192,8 @@ class MessageHub(ManagerMixin):
         time calling ``update_info``.
 
         Note:
-            resumed cannot be set repeatedly for the same key.
+            The ``resumed`` argument needs to be consistent for the same
+            ``key``.
 
         Examples:
             >>> message_hub = MessageHub()
@@ -203,6 +209,31 @@ class MessageHub(ManagerMixin):
         self._resumed_keys[key] = resumed
         self._runtime_info[key] = value
 
+    def update_info_dict(self, info_dict: dict, resumed: bool = True) -> None:
+        """Update runtime information with dictionary.
+
+        The key corresponding runtime information will be overwritten each
+        time calling ``update_info``.
+
+        Note:
+            The ``resumed`` argument needs to be consistent for the same
+            ``info_dict``.
+
+        Examples:
+            >>> message_hub = MessageHub()
+            >>> message_hub.update_info_dict({'iter': 100})
+
+        Args:
+            info_dict (dict): Runtime information dictionary.
+            resumed (bool): Whether the corresponding runtime
+                information could be resumed.
+        """
+        assert isinstance(info_dict, dict), ('`info_dict` must be a dict, '
+                                             f'but got {type(info_dict)}')
+        for key, value in info_dict.items():
+            self._set_resumed_keys(key, resumed)
+            self.update_info(key, value, resumed=resumed)
+
     def _set_resumed_keys(self, key: str, resumed: bool) -> None:
         """Set corresponding resumed keys.
 
@@ -331,7 +362,7 @@ class MessageHub(ManagerMixin):
                         f'just return its reference. ',
                         logger='current',
                         level=logging.WARNING)
-                    saved_scalars[key] = value
+                    saved_info[key] = value
         return dict(
             log_scalars=saved_scalars,
             runtime_info=saved_info,
@@ -359,9 +390,48 @@ class MessageHub(ManagerMixin):
                 assert key in state_dict, (
                     'The loaded `state_dict` of `MessageHub` must contain '
                     f'key: `{key}`')
-            self._log_scalars = copy.deepcopy(state_dict['log_scalars'])
-            self._runtime_info = copy.deepcopy(state_dict['runtime_info'])
-            self._resumed_keys = copy.deepcopy(state_dict['resumed_keys'])
+            # The old `MessageHub` could save non-HistoryBuffer values in
+            # `log_scalars`, so the loaded `log_scalars` must be filtered.
+            for key, value in state_dict['log_scalars'].items():
+                if not isinstance(value, HistoryBuffer):
+                    print_log(
+                        f'{key} in message_hub is not a HistoryBuffer, '
+                        'skip resuming it.',
+                        logger='current',
+                        level=logging.WARNING)
+                    continue
+                self.log_scalars[key] = value
+
+            for key, value in state_dict['runtime_info'].items():
+                try:
+                    self._runtime_info[key] = copy.deepcopy(value)
+                except:  # noqa: E722
+                    print_log(
+                        f'{key} in message_hub cannot be copied, '
+                        f'just return its reference.',
+                        logger='current',
+                        level=logging.WARNING)
+                    self._runtime_info[key] = value
+
+            for key, value in state_dict['resumed_keys'].items():
+                if key not in set(self.log_scalars.keys()) | \
+                        set(self._runtime_info.keys()):
+                    print_log(
+                        f'Resumed key: {key} is not defined in '
+                        'message_hub, skip resuming it.',
+                        logger='current',
+                        level=logging.WARNING)
+                    continue
+                elif not value:
+                    print_log(
+                        f'Although `resumed` is False for key: {key}, it '
+                        'will still be loaded this time, but it will not '
+                        'be saved by the next call to '
+                        '`MessageHub.state_dict()`.',
+                        logger='current',
+                        level=logging.WARNING)
+                self._resumed_keys[key] = value
+
         # Since some checkpoints saved serialized `message_hub` instance,
         # `load_state_dict` support loading `message_hub` instance for
         # compatibility
diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py
index e155263b..75ba3fc5 100644
--- a/mmengine/runner/loops.py
+++ b/mmengine/runner/loops.py
@@ -282,7 +282,6 @@ class IterBasedTrainLoop(BaseLoop):
         # outputs should be a dict of loss.
         outputs = self.runner.model.train_step(
             data_batch, optim_wrapper=self.runner.optim_wrapper)
-        self.runner.message_hub.update_info('train_logs', outputs)
 
         self.runner.call_hook(
             'after_train_iter',
diff --git a/tests/test_logging/test_message_hub.py b/tests/test_logging/test_message_hub.py
index e0e1180d..2e6d5ddc 100644
--- a/tests/test_logging/test_message_hub.py
+++ b/tests/test_logging/test_message_hub.py
@@ -6,7 +6,13 @@ import numpy as np
 import pytest
 import torch
 
-from mmengine import MessageHub
+from mmengine import HistoryBuffer, MessageHub
+
+
+class NoDeepCopy:
+
+    def __deepcopy__(self, memo=None):
+        raise NotImplementedError  # make ``copy.deepcopy`` fail on purpose
 
 
 class TestMessageHub:
@@ -45,6 +51,15 @@ class TestMessageHub:
         message_hub.update_info('key', 1)
         assert message_hub.runtime_info['key'] == 1
 
+    def test_update_info_dict(self):
+        message_hub = MessageHub.get_instance('mmengine')
+        # Test runtime information can be updated in a batch.
+        message_hub.update_info_dict({'a': 2, 'b': 3})
+        assert message_hub.runtime_info['a'] == 2
+        assert message_hub.runtime_info['b'] == 3
+        assert message_hub._resumed_keys['a']
+        assert message_hub._resumed_keys['b']
+
     def test_get_scalar(self):
         message_hub = MessageHub.get_instance('mmengine')
         # Get undefined key will raise error
@@ -102,14 +117,18 @@ class TestMessageHub:
         # update runtime information
         message_hub.update_info('iter', 1, resumed=True)
         message_hub.update_info('tensor', [1, 2, 3], resumed=False)
+        no_copy = NoDeepCopy()
+        message_hub.update_info('no_copy', no_copy, resumed=True)
         state_dict = message_hub.state_dict()
+
         assert state_dict['log_scalars']['loss'].data == (np.array([0.1]),
                                                           np.array([1]))
         assert 'lr' not in state_dict['log_scalars']
         assert state_dict['runtime_info']['iter'] == 1
-        assert 'tensor' not in state_dict
+        assert 'tensor' not in state_dict['runtime_info']
+        assert state_dict['runtime_info']['no_copy'] is no_copy
 
-    def test_load_state_dict(self):
+    def test_load_state_dict(self, capsys):
         message_hub1 = MessageHub.get_instance('test_load_state_dict1')
         # update log_scalars.
         message_hub1.update_scalar('loss', 0.1)
@@ -133,6 +152,23 @@ class TestMessageHub:
                                                         np.array([1]))
         assert message_hub3.get_info('iter') == 1
 
+        # Test resume custom state_dict
+        state_dict = OrderedDict()
+        state_dict['log_scalars'] = dict(a=1, b=HistoryBuffer())
+        state_dict['runtime_info'] = dict(c=1, d=NoDeepCopy(), e=1)
+        state_dict['resumed_keys'] = dict(
+            a=True, b=True, c=True, e=False, f=True)
+
+        message_hub4 = MessageHub.get_instance('test_load_state_dict4')
+        message_hub4.load_state_dict(state_dict)
+        assert 'a' not in message_hub4.log_scalars and 'b' in \
+               message_hub4.log_scalars
+        assert 'c' in message_hub4.runtime_info and \
+               state_dict['runtime_info']['d'] is \
+               message_hub4.runtime_info['d']
+        assert message_hub4._resumed_keys == OrderedDict(
+            b=True, c=True, e=False)
+
     def test_getstate(self):
         message_hub = MessageHub.get_instance('name')
         # update log_scalars.
-- 
GitLab