diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py index 017784b9a7cdfb49c81e1ae7a97840406ef59060..3ed332bd263315c438e89222d486d0138fa28fc2 100644 --- a/mmengine/hooks/checkpoint_hook.py +++ b/mmengine/hooks/checkpoint_hook.py @@ -69,7 +69,7 @@ class CheckpointHook(Hook): self.args = kwargs self.file_client_args = file_client_args - def before_run(self, runner) -> None: + def before_train(self, runner) -> None: """Finish all operations, related to checkpoint. This function will get the appropriate file client, and the directory @@ -78,12 +78,11 @@ class CheckpointHook(Hook): Args: runner (Runner): The runner of the training process. """ - if not self.out_dir: + if self.out_dir is None: self.out_dir = runner.work_dir self.file_client = FileClient.infer_client(self.file_client_args, self.out_dir) - # if `self.out_dir` is not equal to `runner.work_dir`, it means that # `self.out_dir` is set so the final `self.out_dir` is the # concatenation of `self.out_dir` and the last level directory of @@ -186,8 +185,7 @@ class CheckpointHook(Hook): batch_idx (int): The index of the current batch in the train loop. data_batch (Sequence[dict], optional): Data from dataloader. Defaults to None. - outputs (dict, optional): Outputs from model. - Defaults to None. + outputs (dict, optional): Outputs from model. Defaults to None. """ if self.by_epoch: return diff --git a/tests/test_hook/test_checkpoint_hook.py b/tests/test_hook/test_checkpoint_hook.py index 7fabecd57a084ad5f7876d91624df0c9f8a72d37..10f682cd8c6086ba5f417650b24feb8f6e6aa0c7 100644 --- a/tests/test_hook/test_checkpoint_hook.py +++ b/tests/test_hook/test_checkpoint_hook.py @@ -1,13 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. import os -import sys -from tempfile import TemporaryDirectory +import os.path as osp from unittest.mock import Mock, patch from mmengine.hooks import CheckpointHook -sys.modules['file_client'] = sys.modules['mmengine.fileio.file_client'] - class MockPetrel: @@ -30,75 +27,80 @@ prefix_to_backends = {'s3': MockPetrel} class TestCheckpointHook: - @patch('file_client.FileClient._prefix_to_backends', prefix_to_backends) - def test_before_run(self): + @patch('mmengine.fileio.file_client.FileClient._prefix_to_backends', + prefix_to_backends) + def test_before_train(self, tmp_path): runner = Mock() - runner.work_dir = './tmp' + work_dir = str(tmp_path) + runner.work_dir = work_dir # the out_dir of the checkpoint hook is None checkpoint_hook = CheckpointHook(interval=1, by_epoch=True) - checkpoint_hook.before_run(runner) + checkpoint_hook.before_train(runner) assert checkpoint_hook.out_dir == runner.work_dir # the out_dir of the checkpoint hook is not None checkpoint_hook = CheckpointHook( interval=1, by_epoch=True, out_dir='test_dir') - checkpoint_hook.before_run(runner) - assert checkpoint_hook.out_dir == 'test_dir/tmp' + checkpoint_hook.before_train(runner) + assert checkpoint_hook.out_dir == ( + f'test_dir/{osp.basename(work_dir)}') # create_symlink in args and create_symlink is True checkpoint_hook = CheckpointHook( interval=1, by_epoch=True, out_dir='test_dir', create_symlink=True) - checkpoint_hook.before_run(runner) + checkpoint_hook.before_train(runner) assert checkpoint_hook.args['create_symlink'] runner.work_dir = 's3://path/of/file' checkpoint_hook = CheckpointHook( interval=1, by_epoch=True, create_symlink=True) - checkpoint_hook.before_run(runner) + checkpoint_hook.before_train(runner) assert not checkpoint_hook.args['create_symlink'] - def test_after_train_epoch(self): + def test_after_train_epoch(self, tmp_path): runner = Mock() - runner.work_dir = './tmp' + work_dir = str(tmp_path) + runner.work_dir = tmp_path runner.epoch = 9 runner.meta = dict() runner.model = Mock() # by epoch is True checkpoint_hook = CheckpointHook(interval=2, by_epoch=True) - checkpoint_hook.before_run(runner) + checkpoint_hook.before_train(runner) checkpoint_hook.after_train_epoch(runner) assert (runner.epoch + 1) % 2 == 0 - assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/epoch_10.pth' - + assert runner.meta['hook_msgs']['last_ckpt'] == ( + f'{work_dir}/epoch_10.pth') # epoch can not be evenly divided by 2 runner.epoch = 10 checkpoint_hook.after_train_epoch(runner) - assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/epoch_10.pth' + assert runner.meta['hook_msgs']['last_ckpt'] == ( + f'{work_dir}/epoch_10.pth') # by epoch is False runner.epoch = 9 runner.meta = dict() checkpoint_hook = CheckpointHook(interval=2, by_epoch=False) - checkpoint_hook.before_run(runner) + checkpoint_hook.before_train(runner) checkpoint_hook.after_train_epoch(runner) assert runner.meta.get('hook_msgs', None) is None # max_keep_ckpts > 0 - with TemporaryDirectory() as tempo_dir: - runner.work_dir = tempo_dir - os.system(f'touch {tempo_dir}/epoch_8.pth') - checkpoint_hook = CheckpointHook( - interval=2, by_epoch=True, max_keep_ckpts=1) - checkpoint_hook.before_run(runner) - checkpoint_hook.after_train_epoch(runner) - assert (runner.epoch + 1) % 2 == 0 - assert not os.path.exists(f'{tempo_dir}/epoch_8.pth') - - def test_after_train_iter(self): + runner.work_dir = work_dir + os.system(f'touch {work_dir}/epoch_8.pth') + checkpoint_hook = CheckpointHook( + interval=2, by_epoch=True, max_keep_ckpts=1) + checkpoint_hook.before_train(runner) + checkpoint_hook.after_train_epoch(runner) + assert (runner.epoch + 1) % 2 == 0 + assert not os.path.exists(f'{work_dir}/epoch_8.pth') + + def test_after_train_iter(self, tmp_path): + work_dir = str(tmp_path) runner = Mock() - runner.work_dir = './tmp' + runner.work_dir = str(work_dir) runner.iter = 9 batch_idx = 9 runner.meta = dict() @@ -106,29 +108,30 @@ class TestCheckpointHook: # by epoch is True checkpoint_hook = CheckpointHook(interval=2, by_epoch=True) - checkpoint_hook.before_run(runner) + checkpoint_hook.before_train(runner) checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx) assert runner.meta.get('hook_msgs', None) is None # by epoch is False checkpoint_hook = CheckpointHook(interval=2, by_epoch=False) - checkpoint_hook.before_run(runner) + checkpoint_hook.before_train(runner) checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx) assert (runner.iter + 1) % 2 == 0 - assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/iter_10.pth' + assert runner.meta['hook_msgs']['last_ckpt'] == ( + f'{work_dir}/iter_10.pth') # epoch can not be evenly divided by 2 runner.iter = 10 checkpoint_hook.after_train_epoch(runner) - assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/iter_10.pth' + assert runner.meta['hook_msgs']['last_ckpt'] == ( + f'{work_dir}/iter_10.pth') # max_keep_ckpts > 0 runner.iter = 9 - with TemporaryDirectory() as tempo_dir: - runner.work_dir = tempo_dir - os.system(f'touch {tempo_dir}/iter_8.pth') - checkpoint_hook = CheckpointHook( - interval=2, by_epoch=False, max_keep_ckpts=1) - checkpoint_hook.before_run(runner) - checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx) - assert not os.path.exists(f'{tempo_dir}/iter_8.pth') + runner.work_dir = work_dir + os.system(f'touch {work_dir}/iter_8.pth') + checkpoint_hook = CheckpointHook( + interval=2, by_epoch=False, max_keep_ckpts=1) + checkpoint_hook.before_train(runner) + checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx) + assert not os.path.exists(f'{work_dir}/iter_8.pth')