# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
import os.path as osp
import re
import sys
from unittest.mock import MagicMock, patch

import torch
from parameterized import parameterized

from mmengine.evaluator import BaseMetric
from mmengine.fileio import FileClient, LocalBackend
from mmengine.hooks import CheckpointHook
from mmengine.logging import MessageHub
from mmengine.registry import METRICS
from mmengine.testing import RunnerTestCase


class TriangleMetric(BaseMetric):
    """Dummy metric whose accuracy rises and then falls, peaking at the
    ``length // 2``-th evaluation."""

    default_prefix: str = 'test'

    def __init__(self, length):
        super().__init__()
        self.length = length
        self.best_idx = length // 2
        self.cur_idx = 0

    def process(self, *args, **kwargs):
        self.results.append(0)

    def compute_metrics(self, *args, **kwargs):
        self.cur_idx += 1
        acc = 1.0 - abs(self.cur_idx - self.best_idx) / self.length
        return dict(acc=acc)


class TestCheckpointHook(RunnerTestCase):

    def setUp(self):
        super().setUp()
        METRICS.register_module(module=TriangleMetric, force=True)

    def tearDown(self):
        METRICS.module_dict.clear()
        return super().tearDown()

    def test_init(self):
        # Test file_client_args and backend_args
        # TODO: Refactor this test case
        # with self.assertWarnsRegex(
        #         DeprecationWarning,
        #         '"file_client_args" will be deprecated in future'):
        #     CheckpointHook(file_client_args={'backend': 'disk'})

        with self.assertRaisesRegex(
                ValueError,
                '"file_client_args" and "backend_args" cannot be set '
                'at the same time'):
            CheckpointHook(
                file_client_args={'backend': 'disk'},
                backend_args={'backend': 'local'})

        # Test save best
        CheckpointHook(save_best='acc')
        CheckpointHook(save_best=['acc'])
        with self.assertRaisesRegex(AssertionError, '"save_best" should be'):
            CheckpointHook(save_best=dict(acc='acc'))

        # error when 'auto' in `save_best` list
        with self.assertRaisesRegex(AssertionError, 'Only support one'):
            CheckpointHook(interval=2, save_best=['auto', 'acc'])

        # Test rules
        CheckpointHook(save_best=['acc', 'mAcc'], rule='greater')
        with self.assertRaisesRegex(AssertionError, '"rule" should be a str'):
            CheckpointHook(save_best=['acc'], rule=1)
        with self.assertRaisesRegex(AssertionError,
                                    'Number of "rule" must be'):
            CheckpointHook(save_best=['acc'], rule=['greater', 'loss'])

        # Test greater_keys
        hook = CheckpointHook(greater_keys='acc')
        self.assertEqual(hook.greater_keys, ('acc', ))

        hook = CheckpointHook(greater_keys=['acc'])
        self.assertEqual(hook.greater_keys, ['acc'])

        hook = CheckpointHook(
            interval=2, by_epoch=False, save_best=['acc', 'mIoU'])
        self.assertEqual(hook.key_indicators, ['acc', 'mIoU'])
        self.assertEqual(hook.rules, ['greater', 'greater'])

        # Test less keys
        hook = CheckpointHook(less_keys='loss_cls')
        self.assertEqual(hook.less_keys, ('loss_cls', ))

        hook = CheckpointHook(less_keys=['loss_cls'])
        self.assertEqual(hook.less_keys, ['loss_cls'])

    def test_before_train(self):
        cfg = copy.deepcopy(self.epoch_based_cfg)
        runner = self.build_runner(cfg)

        # file_client_args is None
        checkpoint_hook = CheckpointHook()
        checkpoint_hook.before_train(runner)
        self.assertIsInstance(checkpoint_hook.file_client, FileClient)
        self.assertIsInstance(checkpoint_hook.file_backend, LocalBackend)

        # file_client_args is not None
        checkpoint_hook = CheckpointHook(file_client_args={'backend': 'disk'})
        checkpoint_hook.before_train(runner)
        self.assertIsInstance(checkpoint_hook.file_client, FileClient)
        # file_backend is the alias of file_client
        self.assertIs(checkpoint_hook.file_backend,
                      checkpoint_hook.file_client)

        # the out_dir of the checkpoint hook is None
        checkpoint_hook = CheckpointHook(interval=1, by_epoch=True)
        checkpoint_hook.before_train(runner)
        self.assertEqual(checkpoint_hook.out_dir, runner.work_dir)

        # the out_dir of the checkpoint hook is not None
        checkpoint_hook = CheckpointHook(
            interval=1, by_epoch=True, out_dir='test_dir')
        checkpoint_hook.before_train(runner)
        self.assertEqual(checkpoint_hook.out_dir,
                         osp.join('test_dir', osp.basename(cfg.work_dir)))

        # If `save_best` is a list of strings, the path to save the best
        # checkpoint will be defined in attribute `best_ckpt_path_dict`.
        checkpoint_hook = CheckpointHook(interval=1, save_best=['acc', 'mIoU'])
        checkpoint_hook.before_train(runner)
        self.assertEqual(checkpoint_hook.best_ckpt_path_dict,
                         dict(acc=None, mIoU=None))
        self.assertFalse(hasattr(checkpoint_hook, 'best_ckpt_path'))

        # Resume 'best_ckpt_path' from message_hub
        runner.message_hub.update_info('best_ckpt_acc', 'best_acc')
        checkpoint_hook.before_train(runner)
        self.assertEqual(checkpoint_hook.best_ckpt_path_dict,
                         dict(acc='best_acc', mIoU=None))

        # If `save_best` is a string, the path to save the best checkpoint
        # will be defined in attribute `best_ckpt_path`.
        checkpoint_hook = CheckpointHook(interval=1, save_best='acc')
        checkpoint_hook.before_train(runner)
        self.assertIsNone(checkpoint_hook.best_ckpt_path)
        self.assertFalse(hasattr(checkpoint_hook, 'best_ckpt_path_dict'))

        # Resume `best_ckpt` path from message_hub
        runner.message_hub.update_info('best_ckpt', 'best_ckpt')
        checkpoint_hook.before_train(runner)
        self.assertEqual(checkpoint_hook.best_ckpt_path, 'best_ckpt')

    def test_after_val_epoch(self):
        cfg = copy.deepcopy(self.epoch_based_cfg)
        runner = self.build_runner(cfg)
        runner.train_loop._epoch = 9

        # if metrics is an empty dict, a warning should be logged
        with self.assertLogs(runner.logger, level='WARNING'):
            checkpoint_hook = CheckpointHook(
                interval=2, by_epoch=True, save_best='auto')
            checkpoint_hook.after_val_epoch(runner, {})

        # if save_best is None, no best_ckpt meta should be stored
        checkpoint_hook = CheckpointHook(
            interval=2, by_epoch=True, save_best=None)
        checkpoint_hook.before_train(runner)
        checkpoint_hook.after_val_epoch(runner, {})
        self.assertNotIn('best_score', runner.message_hub.runtime_info)
        self.assertNotIn('best_ckpt', runner.message_hub.runtime_info)

        # when `save_best` is set to `auto`, the first metric will be used.
        metrics = {'acc': 0.5, 'map': 0.3}
        checkpoint_hook = CheckpointHook(
            interval=2, by_epoch=True, save_best='auto')
        checkpoint_hook.before_train(runner)
        checkpoint_hook.after_val_epoch(runner, metrics)
        best_ckpt_name = 'best_acc_epoch_9.pth'
        best_ckpt_path = checkpoint_hook.file_client.join_path(
            checkpoint_hook.out_dir, best_ckpt_name)
        self.assertEqual(checkpoint_hook.key_indicators, ['acc'])
        self.assertEqual(checkpoint_hook.rules, ['greater'])
        self.assertEqual(runner.message_hub.get_info('best_score'), 0.5)
        self.assertEqual(
            runner.message_hub.get_info('best_ckpt'), best_ckpt_path)

        # when `save_best` is set to `acc`, the greater value should be kept
        checkpoint_hook = CheckpointHook(
            interval=2, by_epoch=True, save_best='acc')
        checkpoint_hook.before_train(runner)
        metrics['acc'] = 0.8
        checkpoint_hook.after_val_epoch(runner, metrics)
        self.assertEqual(runner.message_hub.get_info('best_score'), 0.8)

        # when `save_best` is set to `loss`, the smaller value should be kept
        checkpoint_hook = CheckpointHook(
            interval=2, by_epoch=True, save_best='loss')
        checkpoint_hook.before_train(runner)
        metrics['loss'] = 0.8
        checkpoint_hook.after_val_epoch(runner, metrics)
        metrics['loss'] = 0.5
        checkpoint_hook.after_val_epoch(runner, metrics)
        self.assertEqual(runner.message_hub.get_info('best_score'), 0.5)

        # when `rule` is set to `less`, the smaller value should be kept
        # no matter what `save_best` is
        checkpoint_hook = CheckpointHook(
            interval=2, by_epoch=True, save_best='acc', rule='less')
        checkpoint_hook.before_train(runner)
        metrics['acc'] = 0.3
        checkpoint_hook.after_val_epoch(runner, metrics)
        self.assertEqual(runner.message_hub.get_info('best_score'), 0.3)

        # when `rule` is set to `greater`, the greater value should be kept
        # no matter what `save_best` is
        checkpoint_hook = CheckpointHook(
            interval=2, by_epoch=True, save_best='loss', rule='greater')
        checkpoint_hook.before_train(runner)
        metrics['loss'] = 1.0
        checkpoint_hook.after_val_epoch(runner, metrics)
        self.assertEqual(runner.message_hub.get_info('best_score'), 1.0)

        # test multi `save_best` with one rule
        checkpoint_hook = CheckpointHook(
            interval=2, save_best=['acc', 'mIoU'], rule='greater')
        self.assertEqual(checkpoint_hook.key_indicators, ['acc', 'mIoU'])
        self.assertEqual(checkpoint_hook.rules, ['greater', 'greater'])

        # test multi `save_best` with multi rules
        checkpoint_hook = CheckpointHook(
            interval=2, save_best=['FID', 'IS'], rule=['less', 'greater'])
        self.assertEqual(checkpoint_hook.key_indicators, ['FID', 'IS'])
        self.assertEqual(checkpoint_hook.rules, ['less', 'greater'])

        # test multi `save_best` with default rule
        checkpoint_hook = CheckpointHook(interval=2, save_best=['acc', 'mIoU'])
        self.assertEqual(checkpoint_hook.key_indicators, ['acc', 'mIoU'])
        self.assertEqual(checkpoint_hook.rules, ['greater', 'greater'])
        runner.message_hub = MessageHub.get_instance(
            'test_after_val_epoch_save_multi_best')
        checkpoint_hook.before_train(runner)
        metrics = dict(acc=0.5, mIoU=0.6)
        checkpoint_hook.after_val_epoch(runner, metrics)
        best_acc_name = 'best_acc_epoch_9.pth'
        best_acc_path = checkpoint_hook.file_client.join_path(
            checkpoint_hook.out_dir, best_acc_name)
        best_mIoU_name = 'best_mIoU_epoch_9.pth'
        best_mIoU_path = checkpoint_hook.file_client.join_path(
            checkpoint_hook.out_dir, best_mIoU_name)
        self.assertEqual(runner.message_hub.get_info('best_score_acc'), 0.5)
        self.assertEqual(runner.message_hub.get_info('best_score_mIoU'), 0.6)
        self.assertEqual(
            runner.message_hub.get_info('best_ckpt_acc'), best_acc_path)
        self.assertEqual(
            runner.message_hub.get_info('best_ckpt_mIoU'), best_mIoU_path)
        # test behavior when by_epoch is False
        cfg = copy.deepcopy(self.iter_based_cfg)
        runner = self.build_runner(cfg)
        runner.train_loop._iter = 9

        # check best ckpt name and best score
        metrics = {'acc': 0.5, 'map': 0.3}
        checkpoint_hook = CheckpointHook(
            interval=2, by_epoch=False, save_best='acc', rule='greater')
        checkpoint_hook.before_train(runner)
        checkpoint_hook.after_val_epoch(runner, metrics)
        self.assertEqual(checkpoint_hook.key_indicators, ['acc'])
        self.assertEqual(checkpoint_hook.rules, ['greater'])
        best_ckpt_name = 'best_acc_iter_9.pth'
        best_ckpt_path = checkpoint_hook.file_client.join_path(
            checkpoint_hook.out_dir, best_ckpt_name)
        self.assertEqual(
            runner.message_hub.get_info('best_ckpt'), best_ckpt_path)
        self.assertEqual(runner.message_hub.get_info('best_score'), 0.5)

        # check best score updating
        metrics['acc'] = 0.666
        checkpoint_hook.after_val_epoch(runner, metrics)
        best_ckpt_name = 'best_acc_iter_9.pth'
        best_ckpt_path = checkpoint_hook.file_client.join_path(
            checkpoint_hook.out_dir, best_ckpt_name)
        self.assertEqual(
            runner.message_hub.get_info('best_ckpt'), best_ckpt_path)
        self.assertEqual(runner.message_hub.get_info('best_score'), 0.666)

        # check the best checkpoint name when `by_epoch` is False
        checkpoint_hook = CheckpointHook(
            interval=2, by_epoch=False, save_best=['acc', 'mIoU'])
        checkpoint_hook.before_train(runner)
        metrics = dict(acc=0.5, mIoU=0.6)
        checkpoint_hook.after_val_epoch(runner, metrics)
        best_acc_name = 'best_acc_iter_9.pth'
        best_acc_path = checkpoint_hook.file_client.join_path(
            checkpoint_hook.out_dir, best_acc_name)
        best_mIoU_name = 'best_mIoU_iter_9.pth'
        best_mIoU_path = checkpoint_hook.file_client.join_path(
            checkpoint_hook.out_dir, best_mIoU_name)
        self.assertEqual(runner.message_hub.get_info('best_score_acc'), 0.5)
        self.assertEqual(runner.message_hub.get_info('best_score_mIoU'), 0.6)
        self.assertEqual(
            runner.message_hub.get_info('best_ckpt_acc'), best_acc_path)
        self.assertEqual(
            runner.message_hub.get_info('best_ckpt_mIoU'), best_mIoU_path)

        # after_val_epoch should not save last_checkpoint
        self.assertFalse(
            osp.isfile(osp.join(runner.work_dir, 'last_checkpoint')))

        # Only one best checkpoint should be reserved
        # disk backend
        for by_epoch, cfg in [(True, self.epoch_based_cfg),
                              (False, self.iter_based_cfg)]:
            self.clear_work_dir()
            cfg = copy.deepcopy(cfg)
            runner = self.build_runner(cfg)
            checkpoint_hook = CheckpointHook(
                interval=2, by_epoch=by_epoch, save_best='acc')
            checkpoint_hook.before_train(runner)
            checkpoint_hook.after_val_epoch(runner, metrics)
            all_files = os.listdir(runner.work_dir)
            best_ckpts = [
                file for file in all_files if file.startswith('best')
            ]
            self.assertTrue(len(best_ckpts) == 1)

        # petrel backend
        # TODO: use real petrel oss bucket to test
        petrel_client = MagicMock()
        for by_epoch, cfg in [(True, self.epoch_based_cfg),
                              (False, self.iter_based_cfg)]:
            isfile = MagicMock(return_value=True)
            self.clear_work_dir()
            with patch.dict(sys.modules, {'petrel_client': petrel_client}), \
                 patch('mmengine.fileio.backends.PetrelBackend.put') as put_mock, \
                 patch('mmengine.fileio.backends.PetrelBackend.remove') as remove_mock, \
                 patch('mmengine.fileio.backends.PetrelBackend.isfile') as isfile:  # noqa: E501
                cfg = copy.deepcopy(cfg)
                runner = self.build_runner(cfg)
                metrics = dict(acc=0.5)
                petrel_client.client.Client = MagicMock(
                    return_value=petrel_client)
                checkpoint_hook = CheckpointHook(
                    interval=2,
                    by_epoch=by_epoch,
                    save_best='acc',
                    backend_args=dict(backend='petrel'))
                checkpoint_hook.before_train(runner)
                checkpoint_hook.after_val_epoch(runner, metrics)
                put_mock.assert_called_once()
                metrics['acc'] += 0.1
                runner.train_loop._epoch += 1
                runner.train_loop._iter += 1
                checkpoint_hook.after_val_epoch(runner, metrics)
                isfile.assert_called_once()
                remove_mock.assert_called_once()

    def test_after_train_epoch(self):
        cfg = copy.deepcopy(self.epoch_based_cfg)
        runner = self.build_runner(cfg)
        runner.train_loop._epoch = 9
        runner.optim_wrapper = runner.build_optim_wrapper(runner.optim_wrapper)

        # by epoch is True
        checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
        checkpoint_hook.before_train(runner)
        checkpoint_hook.after_train_epoch(runner)
        self.assertEqual((runner.epoch + 1) % 2, 0)
        self.assertEqual(
            runner.message_hub.get_info('last_ckpt'),
            osp.join(cfg.work_dir, 'epoch_10.pth'))
        last_ckpt_path = osp.join(cfg.work_dir, 'last_checkpoint')
        self.assertTrue(osp.isfile(last_ckpt_path))
        with open(last_ckpt_path) as f:
            filepath = f.read()
            self.assertEqual(filepath, osp.join(cfg.work_dir, 'epoch_10.pth'))

        # epoch + 1 can not be evenly divided by 2, so no new checkpoint
        # should be saved
        runner.train_loop._epoch = 10
        checkpoint_hook.after_train_epoch(runner)
        self.assertEqual(
            runner.message_hub.get_info('last_ckpt'),
            osp.join(cfg.work_dir, 'epoch_10.pth'))
        runner.message_hub.runtime_info.clear()

        # by epoch is False
        runner.train_loop._epoch = 9
        checkpoint_hook = CheckpointHook(interval=2, by_epoch=False)
        checkpoint_hook.before_train(runner)
        checkpoint_hook.after_train_epoch(runner)
        self.assertNotIn('last_ckpt', runner.message_hub.runtime_info)
        runner.message_hub.runtime_info.clear()

    def test_after_train_iter(self):
        # by epoch is True
        cfg = copy.deepcopy(self.epoch_based_cfg)
        runner = self.build_runner(cfg)
        runner.train_loop._iter = 9
        runner.optim_wrapper = runner.build_optim_wrapper(runner.optim_wrapper)
        checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
        checkpoint_hook.before_train(runner)
        checkpoint_hook.after_train_iter(runner, batch_idx=9)
        self.assertNotIn('last_ckpt', runner.message_hub.runtime_info)

        # by epoch is False
        checkpoint_hook = CheckpointHook(interval=2, by_epoch=False)
        checkpoint_hook.before_train(runner)
        checkpoint_hook.after_train_iter(runner, batch_idx=9)
        self.assertIn('last_ckpt', runner.message_hub.runtime_info)
        self.assertEqual(
            runner.message_hub.get_info('last_ckpt'),
            osp.join(cfg.work_dir, 'iter_10.pth'))

        # iter + 1 can not be evenly divided by 2, so `last_ckpt` should
        # stay unchanged
        runner.train_loop._iter = 10
        checkpoint_hook.after_train_epoch(runner)
        self.assertEqual(
            runner.message_hub.get_info('last_ckpt'),
            osp.join(cfg.work_dir, 'iter_10.pth'))

    @parameterized.expand([['iter'], ['epoch']])
    def test_with_runner(self, training_type):
        common_cfg = getattr(self, f'{training_type}_based_cfg')
        setattr(common_cfg.train_cfg, f'max_{training_type}s', 11)
        checkpoint_cfg = dict(
            type='CheckpointHook',
            interval=2,
            by_epoch=training_type == 'epoch')
        common_cfg.default_hooks = dict(checkpoint=checkpoint_cfg)

        # Test interval
        cfg = copy.deepcopy(common_cfg)
        runner = self.build_runner(cfg)
        runner.train()
        for i in range(1, 11):
            self.assertEqual(
                osp.isfile(osp.join(cfg.work_dir, f'{training_type}_{i}.pth')),
                i % 2 == 0)
        # save_last=True
        self.assertTrue(
            osp.isfile(osp.join(cfg.work_dir, f'{training_type}_11.pth')))
        self.clear_work_dir()

        # Test save_optimizer=False
        cfg = copy.deepcopy(common_cfg)
        runner = self.build_runner(cfg)
        runner.train()
        ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
        self.assertIn('optimizer', ckpt)
        cfg.default_hooks.checkpoint.save_optimizer = False
        runner = self.build_runner(cfg)
        runner.train()
        ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
        self.assertNotIn('optimizer', ckpt)

        # Test save_param_scheduler=False
        cfg = copy.deepcopy(common_cfg)
        cfg.param_scheduler = [
            dict(
                type='LinearLR',
                start_factor=0.1,
                begin=0,
                end=500,
                by_epoch=training_type == 'epoch')
        ]
        runner = self.build_runner(cfg)
        runner.train()
        ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
        self.assertIn('param_schedulers', ckpt)
        cfg.default_hooks.checkpoint.save_param_scheduler = False
        runner = self.build_runner(cfg)
        runner.train()
        ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
        self.assertNotIn('param_schedulers', ckpt)
        self.clear_work_dir()

        # Test out_dir
        cfg = copy.deepcopy(common_cfg)
        out_dir = osp.join(self.temp_dir.name, 'out_dir')
        cfg.default_hooks.checkpoint.out_dir = out_dir
        runner = self.build_runner(cfg)
        runner.train()
        self.assertTrue(
            osp.isfile(
                osp.join(out_dir, osp.basename(cfg.work_dir),
                         f'{training_type}_11.pth')))
        self.clear_work_dir()

        # Test max_keep_ckpts
        cfg = copy.deepcopy(common_cfg)
        cfg.default_hooks.checkpoint.interval = 1
        cfg.default_hooks.checkpoint.max_keep_ckpts = 1
        runner = self.build_runner(cfg)
        runner.train()
        print(os.listdir(cfg.work_dir))
        self.assertTrue(
            osp.isfile(osp.join(cfg.work_dir, f'{training_type}_11.pth')))
        for i in range(11):
            self.assertFalse(
                osp.isfile(osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
        self.clear_work_dir()

        # Test filename_tmpl
        cfg = copy.deepcopy(common_cfg)
        cfg.default_hooks.checkpoint.filename_tmpl = 'test_{}.pth'
        runner = self.build_runner(cfg)
        runner.train()
        self.assertTrue(osp.isfile(osp.join(cfg.work_dir, 'test_11.pth')))
        self.clear_work_dir()

        # Test save_best
        cfg = copy.deepcopy(common_cfg)
        cfg.default_hooks.checkpoint.interval = 1
        cfg.default_hooks.checkpoint.save_best = 'test/acc'
        cfg.val_evaluator = dict(type='TriangleMetric', length=11)
        cfg.train_cfg.val_interval = 1
        runner = self.build_runner(cfg)
        runner.train()
        best_ckpt = osp.join(cfg.work_dir,
                             f'best_test_acc_{training_type}_5.pth')
        self.assertTrue(osp.isfile(best_ckpt))
        self.clear_work_dir()

        # Test save published keys
        cfg = copy.deepcopy(common_cfg)
        cfg.default_hooks.checkpoint.published_keys = ['meta', 'state_dict']
        runner = self.build_runner(cfg)
        runner.train()
        ckpt_files = os.listdir(runner.work_dir)
        self.assertTrue(
            any(re.findall(r'-[\d\w]{8}\.pth', file) for file in ckpt_files))
        self.clear_work_dir()
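

# Convenience entry point so the module can also be run directly with the
# standard unittest runner (assuming RunnerTestCase behaves as a regular
# unittest.TestCase subclass); the suite is normally collected by pytest.
if __name__ == '__main__':
    import unittest
    unittest.main()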