Skip to content
Snippets Groups Projects
Unverified Commit d65350a9 authored by LeoXing1996's avatar LeoXing1996 Committed by GitHub
Browse files

[Fix] Fix bug of not save-best in iteration-based training (#341)

* fix bug of not save-best in iteration-based training

* revise the unit test
parent 59b0ccfe
No related branches found
No related tags found
No related merge requests found
...@@ -186,8 +186,13 @@ class CheckpointHook(Hook): ...@@ -186,8 +186,13 @@ class CheckpointHook(Hook):
self._save_checkpoint(runner) self._save_checkpoint(runner)
def after_val_epoch(self, runner, metrics): def after_val_epoch(self, runner, metrics):
if not self.by_epoch: """Save the checkpoint and synchronize buffers after each evaluation
return epoch.
Args:
runner (Runner): The runner of the training process.
metrics (dict): Evaluation results of all metrics
"""
self._save_best_checkpoint(runner, metrics) self._save_best_checkpoint(runner, metrics)
def _get_metric_score(self, metrics): def _get_metric_score(self, metrics):
......
...@@ -142,6 +142,41 @@ class TestCheckpointHook: ...@@ -142,6 +142,41 @@ class TestCheckpointHook:
assert 'best_score' in runner.message_hub.runtime_info and \ assert 'best_score' in runner.message_hub.runtime_info and \
runner.message_hub.get_info('best_score') == 1.0 runner.message_hub.get_info('best_score') == 1.0
# test behavior when by_epoch is False
runner = Mock()
runner.work_dir = tmp_path
runner.iter = 9
runner.model = Mock()
runner.message_hub = MessageHub.get_instance(
'test_after_val_epoch_by_epoch_is_false')
# check best ckpt name and best score
metrics = {'acc': 0.5, 'map': 0.3}
eval_hook = CheckpointHook(
interval=2, by_epoch=False, save_best='acc', rule='greater')
eval_hook.before_train(runner)
eval_hook.after_val_epoch(runner, metrics)
assert eval_hook.key_indicator == 'acc'
assert eval_hook.rule == 'greater'
best_ckpt_name = 'best_acc_iter_10.pth'
best_ckpt_path = eval_hook.file_client.join_path(
eval_hook.out_dir, best_ckpt_name)
assert 'best_ckpt' in runner.message_hub.runtime_info and \
runner.message_hub.get_info('best_ckpt') == best_ckpt_path
assert 'best_score' in runner.message_hub.runtime_info and \
runner.message_hub.get_info('best_score') == 0.5
# check best score updating
metrics['acc'] = 0.666
eval_hook.after_val_epoch(runner, metrics)
best_ckpt_name = 'best_acc_iter_10.pth'
best_ckpt_path = eval_hook.file_client.join_path(
eval_hook.out_dir, best_ckpt_name)
assert 'best_ckpt' in runner.message_hub.runtime_info and \
runner.message_hub.get_info('best_ckpt') == best_ckpt_path
assert 'best_score' in runner.message_hub.runtime_info and \
runner.message_hub.get_info('best_score') == 0.666
def test_after_train_epoch(self, tmp_path): def test_after_train_epoch(self, tmp_path):
runner = Mock() runner = Mock()
work_dir = str(tmp_path) work_dir = str(tmp_path)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment