diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py
index 334d544fe2d3a10ce49b3a12f653b330e5dcb6dd..8a1d2c31470d6c5d91e6aafcdea6f7dee98ec70b 100644
--- a/mmengine/hooks/checkpoint_hook.py
+++ b/mmengine/hooks/checkpoint_hook.py
@@ -1,11 +1,15 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import os.path as osp
+import warnings
+from collections import OrderedDict
+from math import inf
 from pathlib import Path
 from typing import Optional, Sequence, Union
 
 from mmengine.dist import master_only
 from mmengine.fileio import FileClient
 from mmengine.registry import HOOKS
+from mmengine.utils import is_seq_of
 from .hook import Hook
 
 DATA_BATCH = Optional[Sequence[dict]]
@@ -40,6 +44,26 @@ class CheckpointHook(Hook):
             Defaults to -1, which means unlimited.
         save_last (bool): Whether to force the last checkpoint to be saved
             regardless of interval. Defaults to True.
+        save_best (str, optional): If a metric is specified, the best
+            checkpoint during evaluation will be saved according to that
+            metric. The best score and the best checkpoint path are stored
+            in ``runner.message_hub`` and are also loaded when resuming
+            from a checkpoint. Options are the evaluation metrics on the
+            test dataset, e.g., ``bbox_mAP`` and ``segm_mAP`` for bbox
+            detection and instance segmentation, or ``AR@100`` for proposal
+            recall. If ``save_best`` is ``auto``, the first key of the
+            returned ``OrderedDict`` result will be used. Defaults to None.
+        rule (str, optional): Comparison rule for the best score. If set to
+            None, a reasonable rule will be inferred: keys such as 'acc',
+            'top', etc. are inferred by the 'greater' rule, while keys
+            containing 'loss' are inferred by the 'less' rule. Options are
+            'greater', 'less', None. Defaults to None.
+        greater_keys (List[str], optional): Metric keys that will be
+            inferred by the 'greater' comparison rule. If ``None``,
+            ``_default_greater_keys`` will be used. Defaults to None.
+        less_keys (List[str], optional): Metric keys that will be
+            inferred by the 'less' comparison rule. If ``None``,
+            ``_default_less_keys`` will be used. Defaults to None.
         file_client_args (dict, optional): Arguments to instantiate a
             FileClient. See :class:`mmcv.fileio.FileClient` for details.
             Defaults to None.
@@ -48,6 +72,19 @@ class CheckpointHook(Hook):
 
     priority = 'VERY_LOW'
 
+    # logic to save best checkpoints
+    # Since the key for determining greater or less is related to the
+    # downstream tasks, downstream repositories may need to overwrite
+    # the following inner variables accordingly.
+
+    rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
+    init_value_map = {'greater': -inf, 'less': inf}
+    _default_greater_keys = [
+        'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU',
+        'mAcc', 'aAcc'
+    ]
+    _default_less_keys = ['loss']
+
     def __init__(self,
                  interval: int = -1,
                  by_epoch: bool = True,
@@ -56,6 +93,10 @@ class CheckpointHook(Hook):
                  out_dir: Optional[Union[str, Path]] = None,
                  max_keep_ckpts: int = -1,
                  save_last: bool = True,
+                 save_best: Optional[str] = None,
+                 rule: Optional[str] = None,
+                 greater_keys: Optional[Sequence[str]] = None,
+                 less_keys: Optional[Sequence[str]] = None,
                  file_client_args: Optional[dict] = None,
                  **kwargs) -> None:
         self.interval = interval
@@ -68,6 +109,32 @@ class CheckpointHook(Hook):
         self.args = kwargs
         self.file_client_args = file_client_args
 
+        # save best logic
+        assert isinstance(save_best, str) or save_best is None, \
+            '"save_best" should be a str or None ' \
+            f'rather than {type(save_best)}'
+        self.save_best = save_best
+
+        if greater_keys is None:
+            self.greater_keys = self._default_greater_keys
+        else:
+            if not isinstance(greater_keys, (list, tuple)):
+                greater_keys = (greater_keys, )  # type: ignore
+            assert is_seq_of(greater_keys, str)
+            self.greater_keys = greater_keys  # type: ignore
+
+        if less_keys is None:
+            self.less_keys = self._default_less_keys
+        else:
+            if not isinstance(less_keys, (list, tuple)):
+                less_keys = (less_keys, )  # type: ignore
+            assert is_seq_of(less_keys, str)
+            self.less_keys = less_keys  # type: ignore
+
+        if self.save_best is not None:
+            self.best_ckpt_path = None
+            self._init_rule(rule, self.save_best)
+
     def before_train(self, runner) -> None:
         """Finish all operations, related to checkpoint.
 
@@ -94,6 +161,12 @@ class CheckpointHook(Hook):
         runner.logger.info(f'Checkpoints will be saved to {self.out_dir} by '
                            f'{self.file_client.name}.')
 
+        if self.save_best is not None:
+            if 'best_ckpt' not in runner.message_hub.runtime_info:
+                self.best_ckpt_path = None
+            else:
+                self.best_ckpt_path = runner.message_hub.get_info('best_ckpt')
+
     def after_train_epoch(self, runner) -> None:
         """Save the checkpoint and synchronize buffers after each epoch.
 
@@ -112,6 +185,27 @@
                 f'Saving checkpoint at {runner.epoch + 1} epochs')
             self._save_checkpoint(runner)
 
+    def after_val_epoch(self, runner, metrics):
+        if not self.by_epoch:
+            return
+        self._save_best_checkpoint(runner, metrics)
+
+    def _get_metric_score(self, metrics):
+        eval_res = OrderedDict()
+        if metrics is not None:
+            eval_res.update(metrics)
+
+        if len(eval_res) == 0:
+            warnings.warn(
+                'Since `eval_res` is an empty dict, the behavior to save '
+                'the best checkpoint will be skipped in this evaluation.')
+            return None
+
+        if self.key_indicator == 'auto':
+            self._init_rule(self.rule, list(eval_res.keys())[0])
+
+        return eval_res[self.key_indicator]
+
     @master_only
     def _save_checkpoint(self, runner) -> None:
         """Save the current checkpoint and delete outdated checkpoint.
@@ -135,10 +229,9 @@ class CheckpointHook(Hook):
             by_epoch=self.by_epoch,
             **self.args)
 
-        if runner.meta is not None:
-            runner.meta.setdefault('hook_msgs', dict())
-            runner.meta['hook_msgs']['last_ckpt'] = self.file_client.join_path(
-                self.out_dir, ckpt_filename)
+        runner.message_hub.update_info(
+            'last_ckpt', self.file_client.join_path(self.out_dir,
+                                                    ckpt_filename))
 
         # remove other checkpoints
         if self.max_keep_ckpts > 0:
@@ -160,6 +253,111 @@ class CheckpointHook(Hook):
             else:
                 break
 
+    @master_only
+    def _save_best_checkpoint(self, runner, metrics) -> None:
+        """Save the best checkpoint and delete the outdated best checkpoint.
+
+        Args:
+            runner (Runner): The runner of the training process.
+            metrics (dict): Evaluation results of the validation epoch.
+        """
+        if not self.save_best:
+            return
+
+        if self.by_epoch:
+            ckpt_filename = self.args.get(
+                'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1)
+            cur_type, cur_time = 'epoch', runner.epoch + 1
+        else:
+            ckpt_filename = self.args.get(
+                'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
+            cur_type, cur_time = 'iter', runner.iter + 1
+
+        # save best logic
+        # compare the metric score with the best score in the message hub;
+        # note that `_get_metric_score` also helps to infer
+        # self.rule when self.save_best is `auto`
+        key_score = self._get_metric_score(metrics)
+        if 'best_score' not in runner.message_hub.runtime_info:
+            best_score = self.init_value_map[self.rule]
+        else:
+            best_score = runner.message_hub.get_info('best_score')
+
+        if not key_score or not self.is_better_than(key_score, best_score):
+            return
+
+        best_score = key_score
+        runner.message_hub.update_info('best_score', best_score)
+
+        if self.best_ckpt_path and self.file_client.isfile(
+                self.best_ckpt_path):
+            self.file_client.remove(self.best_ckpt_path)
+            runner.logger.info(
+                f'The previous best checkpoint {self.best_ckpt_path} '
+                'is removed')
+
+        best_ckpt_name = f'best_{self.key_indicator}_{ckpt_filename}'
+        self.best_ckpt_path = self.file_client.join_path(  # type: ignore # noqa: E501
+            self.out_dir, best_ckpt_name)
+        runner.message_hub.update_info('best_ckpt', self.best_ckpt_path)
+        runner.save_checkpoint(
+            self.out_dir,
+            filename=best_ckpt_name,
+            file_client_args=self.file_client_args,
+            save_optimizer=False,
+            save_param_scheduler=False,
+            by_epoch=False)
+        runner.logger.info(
+            f'The best checkpoint with {best_score:0.4f} {self.key_indicator} '
+            f'at {cur_time} {cur_type} is saved to {best_ckpt_name}.')
+
+    def _init_rule(self, rule, key_indicator) -> None:
+        """Initialize rule, key_indicator, comparison_func, and best score.
+        Here is the rule to determine which rule is used for the key
+        indicator when the rule is not specified (note that the key
+        indicator matching is case-insensitive):
+
+        1. If the key indicator is in ``self.greater_keys``, the rule will be
+           specified as 'greater'.
+        2. Or if the key indicator is in ``self.less_keys``, the rule will be
+           specified as 'less'.
+        3. Or if any one item in ``self.greater_keys`` is a substring of
+           key_indicator, the rule will be specified as 'greater'.
+        4. Or if any one item in ``self.less_keys`` is a substring of
+           key_indicator, the rule will be specified as 'less'.
+        Args:
+            rule (str | None): Comparison rule for best score.
+            key_indicator (str | None): Key indicator to determine the
+                comparison rule.
+ """ + + if rule not in self.rule_map and rule is not None: + raise KeyError('rule must be greater, less or None, ' + f'but got {rule}.') + + if rule is None and key_indicator != 'auto': + # `_lc` here means we use the lower case of keys for + # case-insensitive matching + key_indicator_lc = key_indicator.lower() + greater_keys = [key.lower() for key in self.greater_keys] + less_keys = [key.lower() for key in self.less_keys] + + if key_indicator_lc in greater_keys: + rule = 'greater' + elif key_indicator_lc in less_keys: + rule = 'less' + elif any(key in key_indicator_lc for key in greater_keys): + rule = 'greater' + elif any(key in key_indicator_lc for key in less_keys): + rule = 'less' + else: + raise ValueError('Cannot infer the rule for key ' + f'{key_indicator}, thus a specific rule ' + 'must be specified.') + self.rule = rule + self.key_indicator = key_indicator + if self.rule is not None: + self.is_better_than = self.rule_map[self.rule] + def after_train_iter(self, runner, batch_idx: int, diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index 87dc7ec39421e5133d17a9bfa3d29bcbe0bf489f..5f5a033d35d9282f540cdd6718707398e8322ea9 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -396,8 +396,6 @@ class Runner: # register hooks to `self._hooks` self.register_hooks(default_hooks, custom_hooks) - self.meta: dict = dict() - # dump `cfg` to `work_dir` self.dump_config() @@ -1812,13 +1810,6 @@ class Runner: self.train_loop._epoch = checkpoint['meta']['epoch'] self.train_loop._iter = checkpoint['meta']['iter'] - if self.meta is None: - self.meta = {} - - self.meta.setdefault('hook_msgs', {}) - # load `last_ckpt`, `best_score`, `best_ckpt`, etc. for hook messages - self.meta['hook_msgs'].update(checkpoint['meta'].get('hook_msgs', {})) - # check whether the number of GPU used for current experiment # is consistent with resuming from checkpoint if 'config' in checkpoint['meta']: @@ -1959,9 +1950,6 @@ class Runner: raise TypeError( f'meta should be a dict or None, but got {type(meta)}') - if self.meta is not None: - meta.update(self.meta) - if by_epoch: # self.epoch increments 1 after # `self.call_hook('after_train_epoch)` but `save_checkpoint` is diff --git a/tests/test_hook/test_checkpoint_hook.py b/tests/test_hook/test_checkpoint_hook.py index b9267fc0ee6dfd3e8130a565a91f7e3d22e72e13..d2dd3d6ebbcdfd2e12542b10866251dd8ccb20b8 100644 --- a/tests/test_hook/test_checkpoint_hook.py +++ b/tests/test_hook/test_checkpoint_hook.py @@ -3,7 +3,10 @@ import os import os.path as osp from unittest.mock import Mock, patch +import pytest + from mmengine.hooks import CheckpointHook +from mmengine.logging import MessageHub class MockPetrel: @@ -46,36 +49,132 @@ class TestCheckpointHook: assert checkpoint_hook.out_dir == ( f'test_dir/{osp.basename(work_dir)}') + def test_after_val_epoch(self, tmp_path): + runner = Mock() + runner.work_dir = tmp_path + runner.epoch = 9 + runner.model = Mock() + runner.message_hub = MessageHub.get_instance('test_after_val_epoch') + + with pytest.raises(ValueError): + # key_indicator must be valid when rule_map is None + CheckpointHook(interval=2, by_epoch=True, save_best='unsupport') + + with pytest.raises(KeyError): + # rule must be in keys of rule_map + CheckpointHook( + interval=2, by_epoch=True, save_best='auto', rule='unsupport') + + # if eval_res is an empty dict, print a warning information + with pytest.warns(UserWarning) as record_warnings: + eval_hook = CheckpointHook( + interval=2, by_epoch=True, save_best='auto') + 
+            eval_hook._get_metric_score(None)
+        # Since there will be many warnings thrown, we just need to check
+        # whether the expected warning is among them
+        expected_message = (
+            'Since `eval_res` is an empty dict, the behavior to '
+            'save the best checkpoint will be skipped in this '
+            'evaluation.')
+        for warning in record_warnings:
+            if str(warning.message) == expected_message:
+                break
+        else:
+            assert False
+
+        # if save_best is None, no best checkpoint meta should be stored
+        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best=None)
+        eval_hook.before_train(runner)
+        eval_hook.after_val_epoch(runner, None)
+        assert 'best_score' not in runner.message_hub.runtime_info
+        assert 'best_ckpt' not in runner.message_hub.runtime_info
+
+        # when `save_best` is set to `auto`, the first metric will be used.
+        metrics = {'acc': 0.5, 'map': 0.3}
+        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best='auto')
+        eval_hook.before_train(runner)
+        eval_hook.after_val_epoch(runner, metrics)
+        best_ckpt_name = 'best_acc_epoch_10.pth'
+        best_ckpt_path = eval_hook.file_client.join_path(
+            eval_hook.out_dir, best_ckpt_name)
+        assert eval_hook.key_indicator == 'acc'
+        assert eval_hook.rule == 'greater'
+        assert 'best_score' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('best_score') == 0.5
+        assert 'best_ckpt' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('best_ckpt') == best_ckpt_path
+
+        # when `save_best` is set to `acc`, it should update the greater value
+        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best='acc')
+        eval_hook.before_train(runner)
+        metrics['acc'] = 0.8
+        eval_hook.after_val_epoch(runner, metrics)
+        assert 'best_score' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('best_score') == 0.8
+
+        # when `save_best` is set to `loss`, it should update the smaller value
+        eval_hook = CheckpointHook(interval=2, by_epoch=True, save_best='loss')
+        eval_hook.before_train(runner)
+        metrics['loss'] = 0.8
+        eval_hook.after_val_epoch(runner, metrics)
+        metrics['loss'] = 0.5
+        eval_hook.after_val_epoch(runner, metrics)
+        assert 'best_score' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('best_score') == 0.5
+
+        # when `rule` is set to `less`, it should update the smaller value
+        # no matter what `save_best` is
+        eval_hook = CheckpointHook(
+            interval=2, by_epoch=True, save_best='acc', rule='less')
+        eval_hook.before_train(runner)
+        metrics['acc'] = 0.3
+        eval_hook.after_val_epoch(runner, metrics)
+        assert 'best_score' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('best_score') == 0.3
+
+        # when `rule` is set to `greater`, it should update the greater value
+        # no matter what `save_best` is
+        eval_hook = CheckpointHook(
+            interval=2, by_epoch=True, save_best='loss', rule='greater')
+        eval_hook.before_train(runner)
+        metrics['loss'] = 1.0
+        eval_hook.after_val_epoch(runner, metrics)
+        assert 'best_score' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('best_score') == 1.0
+
     def test_after_train_epoch(self, tmp_path):
         runner = Mock()
         work_dir = str(tmp_path)
         runner.work_dir = tmp_path
         runner.epoch = 9
-        runner.meta = dict()
         runner.model = Mock()
+        runner.message_hub = MessageHub.get_instance('test_after_train_epoch')
 
         # by epoch is True
         checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
         checkpoint_hook.before_train(runner)
         checkpoint_hook.after_train_epoch(runner)
         assert (runner.epoch + 1) % 2 == 0
-        assert runner.meta['hook_msgs']['last_ckpt'] == (
-            f'{work_dir}/epoch_10.pth')
+        assert 'last_ckpt' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('last_ckpt') == (
+                f'{work_dir}/epoch_10.pth')
+
         # epoch can not be evenly divided by 2
         runner.epoch = 10
         checkpoint_hook.after_train_epoch(runner)
-        assert runner.meta['hook_msgs']['last_ckpt'] == (
-            f'{work_dir}/epoch_10.pth')
+        assert 'last_ckpt' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('last_ckpt') == (
+                f'{work_dir}/epoch_10.pth')
 
         # by epoch is False
         runner.epoch = 9
-        runner.meta = dict()
+        runner.message_hub = MessageHub.get_instance('test_after_train_epoch1')
         checkpoint_hook = CheckpointHook(interval=2, by_epoch=False)
         checkpoint_hook.before_train(runner)
         checkpoint_hook.after_train_epoch(runner)
-        assert runner.meta.get('hook_msgs', None) is None
+        assert 'last_ckpt' not in runner.message_hub.runtime_info
 
         # max_keep_ckpts > 0
         runner.work_dir = work_dir
         os.system(f'touch {work_dir}/epoch_8.pth')
         checkpoint_hook = CheckpointHook(
@@ -91,28 +190,30 @@ class TestCheckpointHook:
         runner.work_dir = str(work_dir)
         runner.iter = 9
         batch_idx = 9
-        runner.meta = dict()
         runner.model = Mock()
+        runner.message_hub = MessageHub.get_instance('test_after_train_iter')
 
         # by epoch is True
         checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
         checkpoint_hook.before_train(runner)
         checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
-        assert runner.meta.get('hook_msgs', None) is None
+        assert 'last_ckpt' not in runner.message_hub.runtime_info
 
         # by epoch is False
         checkpoint_hook = CheckpointHook(interval=2, by_epoch=False)
         checkpoint_hook.before_train(runner)
         checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
         assert (runner.iter + 1) % 2 == 0
-        assert runner.meta['hook_msgs']['last_ckpt'] == (
-            f'{work_dir}/iter_10.pth')
+        assert 'last_ckpt' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('last_ckpt') == (
+                f'{work_dir}/iter_10.pth')
 
         # epoch can not be evenly divided by 2
         runner.iter = 10
         checkpoint_hook.after_train_epoch(runner)
-        assert runner.meta['hook_msgs']['last_ckpt'] == (
-            f'{work_dir}/iter_10.pth')
+        assert 'last_ckpt' in runner.message_hub.runtime_info and \
+            runner.message_hub.get_info('last_ckpt') == (
+                f'{work_dir}/iter_10.pth')
 
         # max_keep_ckpts > 0
         runner.iter = 9
diff --git a/tests/test_hook/test_logger_hook.py b/tests/test_hook/test_logger_hook.py
index 7c5bd56bacdcf9390f7ce5b4cc4734d91b7b73fa..50fd4562c0e60a90f83be6961d9b58f15a55b9a8 100644
--- a/tests/test_hook/test_logger_hook.py
+++ b/tests/test_hook/test_logger_hook.py
@@ -103,7 +103,6 @@ class TestLoggerHook:
         runner = MagicMock()
         runner.log_processor.get_log_after_iter = MagicMock(
             return_value=(dict(), 'log_str'))
-        runner.meta = dict(exp_name='retinanet')
         runner.logger = MagicMock()
         logger_hook = LoggerHook()
         logger_hook.after_train_iter(runner, batch_idx=999)
diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py
index c46196de4b77ea431daad5ee2eeb5fb7bff0d9f1..b14bf4038611aa0cb9cd7b52cd0144166674ad31 100644
--- a/tests/test_runner/test_runner.py
+++ b/tests/test_runner/test_runner.py
@@ -1682,7 +1682,6 @@ class TestRunner(TestCase):
         ckpt = torch.load(path)
         self.assertEqual(ckpt['meta']['epoch'], 0)
         self.assertEqual(ckpt['meta']['iter'], 12)
-        # self.assertEqual(ckpt['meta']['hook_msgs']['last_ckpt'], path)
         assert isinstance(ckpt['optimizer'], dict)
         assert isinstance(ckpt['param_schedulers'], list)
         self.assertIsInstance(ckpt['message_hub'], MessageHub)
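
For reference, a minimal configuration sketch (not part of this patch) of how the new options would typically be passed to the hook through a runner config. The values (interval, max_keep_ckpts) are illustrative only; the rest of the config (model, dataloaders, loops) is assumed to be defined elsewhere.

# Illustrative sketch, assuming this patch is applied.
# `default_hooks` is the standard MMEngine entry point for configuring
# built-in hooks such as CheckpointHook.
default_hooks = dict(
    checkpoint=dict(
        type='CheckpointHook',
        interval=1,          # save a regular checkpoint every epoch
        save_best='auto',    # track the first metric returned by the evaluator
        rule=None,           # let the hook infer 'greater'/'less' from the key
        max_keep_ckpts=3))   # keep only the three most recent regular checkpoints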
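The rule-inference behavior documented in `_init_rule` can also be illustrated standalone. This is a sketch against the patched hook (not part of the patch): matching is case-insensitive and substring matches count, and an explicit `rule` always overrides inference. The metric names below are hypothetical examples.

# Illustrative sketch, assuming this patch is applied.
from mmengine.hooks import CheckpointHook

hook = CheckpointHook(interval=1, save_best='bbox_mAP')  # 'mAP' is a greater key
assert hook.rule == 'greater'

hook = CheckpointHook(interval=1, save_best='loss_cls')  # contains 'loss'
assert hook.rule == 'less'

hook = CheckpointHook(interval=1, save_best='acc', rule='less')  # explicit rule wins
assert hook.rule == 'less'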