From 216521a93659e885d248796ceffbe584b992e2e1 Mon Sep 17 00:00:00 2001
From: Alex Yang <50511903+imabackstabber@users.noreply.github.com>
Date: Wed, 22 Jun 2022 19:48:46 +0800
Subject: [PATCH] [Feat] Support save best ckpt (#310)

* [Feat] Support save best ckpt

* reformat code

* rename function and reformat code

* fix logging info
---
 mmengine/hooks/checkpoint_hook.py       | 206 +++++++++++++++++++++++-
 mmengine/runner/runner.py               |  12 --
 tests/test_hook/test_checkpoint_hook.py | 129 +++++++++++++--
 tests/test_hook/test_logger_hook.py     |   1 -
 tests/test_runner/test_runner.py        |   1 -
 5 files changed, 317 insertions(+), 32 deletions(-)

diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py
index 334d544f..8a1d2c31 100644
--- a/mmengine/hooks/checkpoint_hook.py
+++ b/mmengine/hooks/checkpoint_hook.py
@@ -1,11 +1,15 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import os.path as osp
+import warnings
+from collections import OrderedDict
+from math import inf
 from pathlib import Path
 from typing import Optional, Sequence, Union
 
 from mmengine.dist import master_only
 from mmengine.fileio import FileClient
 from mmengine.registry import HOOKS
+from mmengine.utils import is_seq_of
 from .hook import Hook
 
 DATA_BATCH = Optional[Sequence[dict]]
@@ -40,6 +44,26 @@ class CheckpointHook(Hook):
             Defaults to -1, which means unlimited.
         save_last (bool): Whether to force the last checkpoint to be
             saved regardless of interval. Defaults to True.
+        save_best (str, optional): If a metric is specified, it would measure
+            the best checkpoint during evaluation. The information about best
+            checkpoint would be saved in ``runner.message_hub`` to keep
+            best score value and best checkpoint path, which will be also
+            loaded when resuming checkpoint. Options are the evaluation metrics
+            on the test dataset. e.g., ``bbox_mAP``, ``segm_mAP`` for bbox
+            detection and instance segmentation. ``AR@100`` for proposal
+            recall. If ``save_best`` is ``auto``, the first key of the returned
+            ``OrderedDict`` result will be used. Defaults to None.
+        rule (str, optional): Comparison rule for best score. If set to
+            None, it will infer a reasonable rule. Keys such as 'acc', 'top'
+            etc. will be inferred by 'greater' rule. Keys containing 'loss' will
+            be inferred by 'less' rule. Options are 'greater', 'less', None.
+            Defaults to None.
+        greater_keys (List[str], optional): Metric keys that will be
+            inferred by 'greater' comparison rule. If ``None``,
+            _default_greater_keys will be used. Defaults to None.
+        less_keys (List[str], optional): Metric keys that will be
+            inferred by 'less' comparison rule. If ``None``, _default_less_keys
+            will be used. Defaults to None.
         file_client_args (dict, optional): Arguments to instantiate a
             FileClient. See :class:`mmcv.fileio.FileClient` for details.
             Defaults to None.
@@ -48,6 +72,19 @@ class CheckpointHook(Hook):
 
     priority = 'VERY_LOW'
 
+    # logic to save best checkpoints
+    # Since the key for determining greater or less is related to the
+    # downstream tasks, downstream repositories may need to overwrite
+    # the following inner variables accordingly.
+
+    rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
+    init_value_map = {'greater': -inf, 'less': inf}
+    _default_greater_keys = [
+        'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU',
+        'mAcc', 'aAcc'
+    ]
+    _default_less_keys = ['loss']
+
     def __init__(self,
                  interval: int = -1,
                  by_epoch: bool = True,
@@ -56,6 +93,10 @@ class CheckpointHook(Hook):
                  out_dir: Optional[Union[str, Path]] = None,
                  max_keep_ckpts: int = -1,
                  save_last: bool = True,
+                 save_best: Optional[str] = None,
+                 rule: Optional[str] = None,
+                 greater_keys: Optional[Sequence[str]] = None,
+                 less_keys: Optional[Sequence[str]] = None,
                  file_client_args: Optional[dict] = None,
                  **kwargs) -> None:
         self.interval = interval
@@ -68,6 +109,32 @@ class CheckpointHook(Hook):
         self.args = kwargs
         self.file_client_args = file_client_args
 
+        # save best logic
+        assert isinstance(save_best, str) or save_best is None, \
+            '"save_best" should be a str or None ' \
+            f'rather than {type(save_best)}'
+        self.save_best = save_best
+
+        if greater_keys is None:
+            self.greater_keys = self._default_greater_keys
+        else:
+            if not isinstance(greater_keys, (list, tuple)):
+                greater_keys = (greater_keys, )  # type: ignore
+            assert is_seq_of(greater_keys, str)
+            self.greater_keys = greater_keys  # type: ignore
+
+        if less_keys is None:
+            self.less_keys = self._default_less_keys
+        else:
+            if not isinstance(less_keys, (list, tuple)):
+                less_keys = (less_keys, )  # type: ignore
+            assert is_seq_of(less_keys, str)
+            self.less_keys = less_keys  # type: ignore
+
+        if self.save_best is not None:
+            self.best_ckpt_path = None
+            self._init_rule(rule, self.save_best)
+
     def before_train(self, runner) -> None:
         """Finish all operations, related to checkpoint.
 
@@ -94,6 +161,12 @@ class CheckpointHook(Hook):
         runner.logger.info(f'Checkpoints will be saved to {self.out_dir} by '
                            f'{self.file_client.name}.')
 
+        if self.save_best is not None:
+            if 'best_ckpt' not in runner.message_hub.runtime_info:
+                self.best_ckpt_path = None
+            else:
+                self.best_ckpt_path = runner.message_hub.get_info('best_ckpt')
+
     def after_train_epoch(self, runner) -> None:
         """Save the checkpoint and synchronize buffers after each epoch.
 
@@ -112,6 +185,27 @@ class CheckpointHook(Hook):
                 f'Saving checkpoint at {runner.epoch + 1} epochs')
             self._save_checkpoint(runner)
 
+    def after_val_epoch(self, runner, metrics):
+        if not self.by_epoch:
+            return
+        self._save_best_checkpoint(runner, metrics)
+
+    def _get_metric_score(self, metrics):
+        eval_res = OrderedDict()
+        if metrics is not None:
+            eval_res.update(metrics)
+
+        if len(eval_res) == 0:
+            warnings.warn(
+                'Since `eval_res` is an empty dict, the behavior to save '
+                'the best checkpoint will be skipped in this evaluation.')
+            return None
+
+        if self.key_indicator == 'auto':
+            self._init_rule(self.rule, list(eval_res.keys())[0])
+
+        return eval_res[self.key_indicator]
+
     @master_only
     def _save_checkpoint(self, runner) -> None:
         """Save the current checkpoint and delete outdated checkpoint.
@@ -135,10 +229,9 @@ class CheckpointHook(Hook):
             by_epoch=self.by_epoch,
             **self.args)
 
-        if runner.meta is not None:
-            runner.meta.setdefault('hook_msgs', dict())
-            runner.meta['hook_msgs']['last_ckpt'] = self.file_client.join_path(
-                self.out_dir, ckpt_filename)
+        runner.message_hub.update_info(
+            'last_ckpt', self.file_client.join_path(self.out_dir,
+                                                    ckpt_filename))
 
         # remove other checkpoints
         if self.max_keep_ckpts > 0:
@@ -160,6 +253,111 @@ class CheckpointHook(Hook):
                 else:
                     break
 
+    @master_only
+    def _save_best_checkpoint(self, runner, metrics) -> None:
+        """Save the current checkpoint and delete outdated checkpoint.
+
+        Args:
+            runner (Runner): The runner of the training process.
+        """
+        if not self.save_best:
+            return
+
+        if self.by_epoch:
+            ckpt_filename = self.args.get(
+                'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1)
+            cur_type, cur_time = 'epoch', runner.epoch + 1
+        else:
+            ckpt_filename = self.args.get(
+                'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
+            cur_type, cur_time = 'iter', runner.iter + 1
+
+        # save best logic
+        # get score from messagehub
+        # notice `_get_metric_score` helps to infer
+        # self.rule when self.save_best is `auto`
+        key_score = self._get_metric_score(metrics)
+        if 'best_score' not in runner.message_hub.runtime_info:
+            best_score = self.init_value_map[self.rule]
+        else:
+            best_score = runner.message_hub.get_info('best_score')
+
+        if not key_score or not self.is_better_than(key_score, best_score):
+            return
+
+        best_score = key_score
+        runner.message_hub.update_info('best_score', best_score)
+
+        if self.best_ckpt_path and self.file_client.isfile(
+                self.best_ckpt_path):
+            self.file_client.remove(self.best_ckpt_path)
+            runner.logger.info(
+                f'The previous best checkpoint {self.best_ckpt_path} '
+                'is removed')
+
+        best_ckpt_name = f'best_{self.key_indicator}_{ckpt_filename}'
+        self.best_ckpt_path = self.file_client.join_path(  # type: ignore # noqa: E501
+            self.out_dir, best_ckpt_name)
+        runner.message_hub.update_info('best_ckpt', self.best_ckpt_path)
+        runner.save_checkpoint(
+            self.out_dir,
+            filename=best_ckpt_name,
+            file_client_args=self.file_client_args,
+            save_optimizer=False,
+            save_param_scheduler=False,
+            by_epoch=False)
+        runner.logger.info(
+            f'The best checkpoint with {best_score:0.4f} {self.key_indicator} '
+            f'at {cur_time} {cur_type} is saved to {best_ckpt_name}.')
+
+    def _init_rule(self, rule, key_indicator) -> None:
+        """Initialize rule, key_indicator, comparison_func, and best score.
+        Here is the rule to determine which rule is used for key indicator when
+        the rule is not specific (note that the key indicator matching is case-
+        insensitive):
+
+        1. If the key indicator is in ``self.greater_keys``, the rule will be
+        specified as 'greater'.
+        2. Or if the key indicator is in ``self.less_keys``, the rule will be
+        specified as 'less'.
+        3. Or if any one item in ``self.greater_keys`` is a substring of
+            key_indicator, the rule will be specified as 'greater'.
+        4. Or if any one item in ``self.less_keys`` is a substring of
+            key_indicator, the rule will be specified as 'less'.
+        Args:
+            rule (str | None): Comparison rule for best score.
+            key_indicator (str | None): Key indicator to determine the
+                comparison rule.
+        """
+
+        if rule not in self.rule_map and rule is not None:
+            raise KeyError('rule must be greater, less or None, '
+                           f'but got {rule}.')
+
+        if rule is None and key_indicator != 'auto':
+            # `_lc` here means we use the lower case of keys for
+            # case-insensitive matching
+            key_indicator_lc = key_indicator.lower()
+            greater_keys = [key.lower() for key in self.greater_keys]
+            less_keys = [key.lower() for key in self.less_keys]
+
+            if key_indicator_lc in greater_keys:
+                rule = 'greater'
+            elif key_indicator_lc in less_keys:
+                rule = 'less'
+            elif any(key in key_indicator_lc for key in greater_keys):
+                rule = 'greater'
+            elif any(key in key_indicator_lc for key in less_keys):
+                rule = 'less'
+            else:
+                raise ValueError('Cannot infer the rule for key '
+                                 f'{key_indicator}, thus a specific rule '
+                                 'must be specified.')
+        self.rule = rule
+        self.key_indicator = key_indicator
+        if self.rule is not None:
+            self.is_better_than = self.rule_map[self.rule]
+
     def after_train_iter(self,
                          runner,
                          batch_idx: int,
diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index 87dc7ec3..5f5a033d 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -396,8 +396,6 @@ class Runner:
         # register hooks to `self._hooks`
         self.register_hooks(default_hooks, custom_hooks)
 
-        self.meta: dict = dict()
-
         # dump `cfg` to `work_dir`
         self.dump_config()
 
@@ -1812,13 +1810,6 @@ class Runner:
         self.train_loop._epoch = checkpoint['meta']['epoch']
         self.train_loop._iter = checkpoint['meta']['iter']
 
-        if self.meta is None:
-            self.meta = {}
-
-        self.meta.setdefault('hook_msgs', {})
-        # load `last_ckpt`, `best_score`, `best_ckpt`, etc. for hook messages
-        self.meta['hook_msgs'].update(checkpoint['meta'].get('hook_msgs', {}))
-
         # check whether the number of GPU used for current experiment
         # is consistent with resuming from checkpoint
         if 'config' in checkpoint['meta']:
@@ -1959,9 +1950,6 @@ class Runner:
             raise TypeError(
                 f'meta should be a dict or None, but got {type(meta)}')
 
-        if self.meta is not None:
-            meta.update(self.meta)
-
         if by_epoch:
             # self.epoch increments 1 after
             # `self.call_hook('after_train_epoch)` but `save_checkpoint` is
diff --git a/tests/test_hook/test_checkpoint_hook.py b/tests/test_hook/test_checkpoint_hook.py
index b9267fc0..d2dd3d6e 100644
--- a/tests/test_hook/test_checkpoint_hook.py
+++ b/tests/test_hook/test_checkpoint_hook.py
@@ -3,7 +3,10 @@ import os
 import os.path as osp
 from unittest.mock import Mock, patch
 
+import pytest
+
 from mmengine.hooks import CheckpointHook
+from mmengine.logging import MessageHub
 
 
 class MockPetrel:
@@ -46,36 +49,132 @@ class TestCheckpointHook:
         assert checkpoint_hook.out_dir == (
             f'test_dir/{osp.basename(work_dir)}')
 
+    def test_after_val_epoch(self, tmp_path):
+        runner = Mock()
+        runner.work_dir = tmp_path
+        runner.epoch = 9
+        runner.model = Mock()
+        runner.message_hub = MessageHub.get_instance('test_after_val_epoch')
+
+        with pytest.raises(ValueError):
+            # rule cannot be inferred for an unsupported key when rule is None
+            CheckpointHook(interval=2, by_epoch=True, save_best='unsupport')
+
+        with pytest.raises(KeyError):
+            # rule must be in keys of rule_map
+            CheckpointHook(
+                interval=2, by_epoch=True, save_best='auto', rule='unsupport')
+
+        # if eval_res is an empty dict, a warning message should be printed
+        with pytest.warns(UserWarning) as record_warnings:
+            eval_hook = CheckpointHook(
+                interval=2, by_epoch=True, save_best='auto')
+            eval_hook._get_metric_score(None)
+        # Since there will be many warnings thrown, we just need to check
+        # if the expected exceptions are thrown
+        expected_message = (
+            'Since `eval_res` is an empty dict, the behavior to '
+            'save the best checkpoint will be skipped in this '
+            'evaluation.')
+        for warning in record_warnings:
+            if str(warning.message) == expected_message:
+                break
+        else:
+            assert False
+
+        # if save_best is None, no best_ckpt meta should be stored
+        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best=None)
+        eval_hook.before_train(runner)
+        eval_hook.after_val_epoch(runner, None)
+        assert 'best_score' not in runner.message_hub.runtime_info
+        assert 'best_ckpt' not in runner.message_hub.runtime_info
+
+        # when `save_best` is set to `auto`, first metric will be used.
+        metrics = {'acc': 0.5, 'map': 0.3}
+        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best='auto')
+        eval_hook.before_train(runner)
+        eval_hook.after_val_epoch(runner, metrics)
+        best_ckpt_name = 'best_acc_epoch_10.pth'
+        best_ckpt_path = eval_hook.file_client.join_path(
+            eval_hook.out_dir, best_ckpt_name)
+        assert eval_hook.key_indicator == 'acc'
+        assert eval_hook.rule == 'greater'
+        assert 'best_score' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('best_score') == 0.5
+        assert 'best_ckpt' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('best_ckpt') == best_ckpt_path
+
+        # when `save_best` is set to `acc`, it should update greater value
+        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best='acc')
+        eval_hook.before_train(runner)
+        metrics['acc'] = 0.8
+        eval_hook.after_val_epoch(runner, metrics)
+        assert 'best_score' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('best_score') == 0.8
+
+        # when `save_best` is set to `loss`, it should update less value
+        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best='loss')
+        eval_hook.before_train(runner)
+        metrics['loss'] = 0.8
+        eval_hook.after_val_epoch(runner, metrics)
+        metrics['loss'] = 0.5
+        eval_hook.after_val_epoch(runner, metrics)
+        assert 'best_score' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('best_score') == 0.5
+
+        # when `rule` is set to `less`, then it should update less value
+        # no matter what `save_best` is
+        eval_hook = CheckpointHook(
+            interval=2, by_epoch=True, save_best='acc', rule='less')
+        eval_hook.before_train(runner)
+        metrics['acc'] = 0.3
+        eval_hook.after_val_epoch(runner, metrics)
+        assert 'best_score' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('best_score') == 0.3
+
+        # when `rule` is set to `greater`, then it should update greater value
+        # no matter what `save_best` is
+        eval_hook = CheckpointHook(
+            interval=2, by_epoch=True, save_best='loss', rule='greater')
+        eval_hook.before_train(runner)
+        metrics['loss'] = 1.0
+        eval_hook.after_val_epoch(runner, metrics)
+        assert 'best_score' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('best_score') == 1.0
+
     def test_after_train_epoch(self, tmp_path):
         runner = Mock()
         work_dir = str(tmp_path)
         runner.work_dir = tmp_path
         runner.epoch = 9
-        runner.meta = dict()
         runner.model = Mock()
+        runner.message_hub = MessageHub.get_instance('test_after_train_epoch')
 
         # by epoch is True
         checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
         checkpoint_hook.before_train(runner)
         checkpoint_hook.after_train_epoch(runner)
         assert (runner.epoch + 1) % 2 == 0
-        assert runner.meta['hook_msgs']['last_ckpt'] == (
-            f'{work_dir}/epoch_10.pth')
+        assert 'last_ckpt' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('last_ckpt') == (
+                f'{work_dir}/epoch_10.pth')
+
         # epoch can not be evenly divided by 2
         runner.epoch = 10
         checkpoint_hook.after_train_epoch(runner)
-        assert runner.meta['hook_msgs']['last_ckpt'] == (
-            f'{work_dir}/epoch_10.pth')
+        assert 'last_ckpt' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('last_ckpt') == (
+                f'{work_dir}/epoch_10.pth')
 
         # by epoch is False
         runner.epoch = 9
-        runner.meta = dict()
+        runner.message_hub = MessageHub.get_instance('test_after_train_epoch1')
         checkpoint_hook = CheckpointHook(interval=2, by_epoch=False)
         checkpoint_hook.before_train(runner)
         checkpoint_hook.after_train_epoch(runner)
-        assert runner.meta.get('hook_msgs', None) is None
+        assert 'last_ckpt' not in runner.message_hub.runtime_info
 
-        # max_keep_ckpts > 0
+        # max_keep_ckpts > 0
         runner.work_dir = work_dir
         os.system(f'touch {work_dir}/epoch_8.pth')
         checkpoint_hook = CheckpointHook(
@@ -91,28 +190,30 @@ class TestCheckpointHook:
         runner.work_dir = str(work_dir)
         runner.iter = 9
         batch_idx = 9
-        runner.meta = dict()
         runner.model = Mock()
+        runner.message_hub = MessageHub.get_instance('test_after_train_iter')
 
         # by epoch is True
         checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
         checkpoint_hook.before_train(runner)
         checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
-        assert runner.meta.get('hook_msgs', None) is None
+        assert 'last_ckpt' not in runner.message_hub.runtime_info
 
         # by epoch is False
         checkpoint_hook = CheckpointHook(interval=2, by_epoch=False)
         checkpoint_hook.before_train(runner)
         checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
         assert (runner.iter + 1) % 2 == 0
-        assert runner.meta['hook_msgs']['last_ckpt'] == (
-            f'{work_dir}/iter_10.pth')
+        assert 'last_ckpt' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('last_ckpt') == (
+                f'{work_dir}/iter_10.pth')
 
         # epoch can not be evenly divided by 2
         runner.iter = 10
         checkpoint_hook.after_train_epoch(runner)
-        assert runner.meta['hook_msgs']['last_ckpt'] == (
-            f'{work_dir}/iter_10.pth')
+        assert 'last_ckpt' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('last_ckpt') == (
+                f'{work_dir}/iter_10.pth')
 
         # max_keep_ckpts > 0
         runner.iter = 9
diff --git a/tests/test_hook/test_logger_hook.py b/tests/test_hook/test_logger_hook.py
index 7c5bd56b..50fd4562 100644
--- a/tests/test_hook/test_logger_hook.py
+++ b/tests/test_hook/test_logger_hook.py
@@ -103,7 +103,6 @@ class TestLoggerHook:
         runner = MagicMock()
         runner.log_processor.get_log_after_iter = MagicMock(
             return_value=(dict(), 'log_str'))
-        runner.meta = dict(exp_name='retinanet')
         runner.logger = MagicMock()
         logger_hook = LoggerHook()
         logger_hook.after_train_iter(runner, batch_idx=999)
diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py
index c46196de..b14bf403 100644
--- a/tests/test_runner/test_runner.py
+++ b/tests/test_runner/test_runner.py
@@ -1682,7 +1682,6 @@ class TestRunner(TestCase):
         ckpt = torch.load(path)
         self.assertEqual(ckpt['meta']['epoch'], 0)
         self.assertEqual(ckpt['meta']['iter'], 12)
-        # self.assertEqual(ckpt['meta']['hook_msgs']['last_ckpt'], path)
         assert isinstance(ckpt['optimizer'], dict)
         assert isinstance(ckpt['param_schedulers'], list)
         self.assertIsInstance(ckpt['message_hub'], MessageHub)
-- 
GitLab