Skip to content
Snippets Groups Projects
Unverified Commit cf239a2b authored by Yuan Liu's avatar Yuan Liu Committed by GitHub
Browse files

[Feature]: Add checkpoint hook (#66)

* [Feature]: Add checkpoint hook

* [Fix]: Fix lint

* [Fix]: Delete redundant optional and give an example to our_dir

* [Feature]: Add test the last_ckpt in UT

* [Fix]: Fix docstring problem

* [Fix]: Add patch to UT

* [Feature]: Add Test case for by epoch
parent 24483803
No related branches found
No related tags found
No related merge requests found
# Copyright (c) OpenMMLab. All rights reserved.
from .empty_cache_hook import EmptyCacheHook
from .checkpoint_hook import CheckpointHook
from .hook import Hook
from .iter_timer_hook import IterTimerHook
from .optimizer_hook import OptimizerHook
......@@ -8,5 +9,5 @@ from .sampler_seed_hook import DistSamplerSeedHook
__all__ = [
'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
'OptimizerHook', 'EmptyCacheHook'
'OptimizerHook', 'EmptyCacheHook', 'CheckpointHook'
]
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from pathlib import Path
from typing import Optional, Sequence, Union
from mmengine.data import BaseDataSample
from mmengine.fileio import FileClient
from mmengine.registry import HOOKS
from .hook import Hook
@HOOKS.register_module()
class CheckpointHook(Hook):
"""Save checkpoints periodically.
Args:
interval (int): The saving period. If ``by_epoch=True``, interval
indicates epochs, otherwise it indicates iterations.
Default: -1, which means "never".
by_epoch (bool): Saving checkpoints by epoch or by iteration.
Default: True.
save_optimizer (bool): Whether to save optimizer state_dict in the
checkpoint. It is usually used for resuming experiments.
Default: True.
out_dir (str, optional | Path): The root directory to save checkpoints.
If not specified, ``runner.work_dir`` will be used by default. If
specified, the ``out_dir`` will be the concatenation of ``out_dir``
and the last level directory of ``runner.work_dir``. For example,
if the input ``our_dir`` is ``./tmp`` and ``runner.work_dir`` is
``./work_dir/cur_exp``, then the ckpt will be saved in
``./tmp/cur_exp``. Deafule to None.
max_keep_ckpts (int): The maximum checkpoints to keep.
In some cases we want only the latest few checkpoints and would
like to delete old ones to save the disk space.
Default: -1, which means unlimited.
save_last (bool): Whether to force the last checkpoint to be
saved regardless of interval. Default: True.
sync_buffer (bool): Whether to synchronize buffers in
different gpus. Default: False.
file_client_args (dict, optional): Arguments to instantiate a
FileClient. See :class:`mmcv.fileio.FileClient` for details.
Default: None.
"""
def __init__(self,
interval: int = -1,
by_epoch: bool = True,
save_optimizer: bool = True,
out_dir: Union[str, Path] = None,
max_keep_ckpts: int = -1,
save_last: bool = True,
sync_buffer: bool = False,
file_client_args: Optional[dict] = None,
**kwargs) -> None:
self.interval = interval
self.by_epoch = by_epoch
self.save_optimizer = save_optimizer
self.out_dir = out_dir
self.max_keep_ckpts = max_keep_ckpts
self.save_last = save_last
self.args = kwargs
self.sync_buffer = sync_buffer
self.file_client_args = file_client_args
def before_run(self, runner: object) -> None:
"""Finish all operations, related to checkpoint.
This function will get the appropriate file client, and the directory
to save these checkpoints of the model.
Args:
runner (object): The runner of the training process.
"""
if not self.out_dir:
self.out_dir = runner.work_dir # type: ignore
self.file_client = FileClient.infer_client(self.file_client_args,
self.out_dir)
# if `self.out_dir` is not equal to `runner.work_dir`, it means that
# `self.out_dir` is set so the final `self.out_dir` is the
# concatenation of `self.out_dir` and the last level directory of
# `runner.work_dir`
if self.out_dir != runner.work_dir: # type: ignore
basename = osp.basename(
runner.work_dir.rstrip( # type: ignore
osp.sep))
self.out_dir = self.file_client.join_path(
self.out_dir, # type: ignore
basename)
runner.logger.info(( # type: ignore
f'Checkpoints will be saved to {self.out_dir} by '
f'{self.file_client.name}.'))
# disable the create_symlink option because some file backends do not
# allow to create a symlink
if 'create_symlink' in self.args:
if self.args[
'create_symlink'] and not self.file_client.allow_symlink:
self.args['create_symlink'] = False
warnings.warn(
('create_symlink is set as True by the user but is changed'
'to be False because creating symbolic link is not '
f'allowed in {self.file_client.name}'))
else:
self.args['create_symlink'] = self.file_client.allow_symlink
def after_train_epoch(self, runner: object) -> None:
"""Save the checkpoint and synchronize buffers after each epoch.
Args:
runner (object): The runner of the training process.
"""
if not self.by_epoch:
return
# save checkpoint for following cases:
# 1. every ``self.interval`` epochs
# 2. reach the last epoch of training
if self.every_n_epochs(
runner, self.interval) or (self.save_last
and self.is_last_epoch(runner)):
runner.logger.info( # type: ignore
f'Saving checkpoint at \
{runner.epoch + 1} epochs') # type: ignore
if self.sync_buffer:
pass
# TODO
self._save_checkpoint(runner)
# TODO Add master_only decorator
def _save_checkpoint(self, runner: object) -> None:
"""Save the current checkpoint and delete outdated checkpoint.
Args:
runner (object): The runner of the training process.
"""
runner.save_checkpoint( # type: ignore
self.out_dir,
save_optimizer=self.save_optimizer,
**self.args)
if runner.meta is not None: # type: ignore
if self.by_epoch:
cur_ckpt_filename = self.args.get(
'filename_tmpl',
'epoch_{}.pth').format(runner.epoch + 1) # type: ignore
else:
cur_ckpt_filename = self.args.get(
'filename_tmpl',
'iter_{}.pth').format(runner.iter + 1) # type: ignore
runner.meta.setdefault('hook_msgs', dict()) # type: ignore
runner.meta['hook_msgs'][ # type: ignore
'last_ckpt'] = self.file_client.join_path(
self.out_dir, cur_ckpt_filename) # type: ignore
# remove other checkpoints
if self.max_keep_ckpts > 0:
if self.by_epoch:
name = 'epoch_{}.pth'
current_ckpt = runner.epoch + 1 # type: ignore
else:
name = 'iter_{}.pth'
current_ckpt = runner.iter + 1 # type: ignore
redundant_ckpts = range(
current_ckpt - self.max_keep_ckpts * self.interval, 0,
-self.interval)
filename_tmpl = self.args.get('filename_tmpl', name)
for _step in redundant_ckpts:
ckpt_path = self.file_client.join_path(
self.out_dir, filename_tmpl.format(_step)) # type: ignore
if self.file_client.isfile(ckpt_path):
self.file_client.remove(ckpt_path)
else:
break
def after_train_iter(
self,
runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""Save the checkpoint and synchronize buffers after each iteration.
Args:
runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample]): Data from dataloader.
Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from model.
Defaults to None.
"""
if self.by_epoch:
return
# save checkpoint for following cases:
# 1. every ``self.interval`` iterations
# 2. reach the last iteration of training
if self.every_n_iters(
runner, self.interval) or (self.save_last
and self.is_last_iter(runner)):
runner.logger.info( # type: ignore
f'Saving checkpoint at \
{runner.iter + 1} iterations') # type: ignore
if self.sync_buffer:
pass
# TODO
self._save_checkpoint(runner)
# Copyright (c) OpenMMLab. All rights reserved.
import os
import sys
from tempfile import TemporaryDirectory
from unittest.mock import Mock, patch
from mmengine.hooks import CheckpointHook
sys.modules['file_client'] = sys.modules['mmengine.fileio.file_client']
class MockPetrel:
_allow_symlink = False
def __init__(self):
pass
@property
def name(self):
return self.__class__.__name__
@property
def allow_symlink(self):
return self._allow_symlink
prefix_to_backends = {'s3': MockPetrel}
class TestCheckpointHook:
@patch('file_client.FileClient._prefix_to_backends', prefix_to_backends)
def test_before_run(self):
runner = Mock()
runner.work_dir = './tmp'
# the out_dir of the checkpoint hook is None
checkpoint_hook = CheckpointHook(interval=1, by_epoch=True)
checkpoint_hook.before_run(runner)
assert checkpoint_hook.out_dir == runner.work_dir
# the out_dir of the checkpoint hook is not None
checkpoint_hook = CheckpointHook(
interval=1, by_epoch=True, out_dir='test_dir')
checkpoint_hook.before_run(runner)
assert checkpoint_hook.out_dir == 'test_dir/tmp'
# create_symlink in args and create_symlink is True
checkpoint_hook = CheckpointHook(
interval=1, by_epoch=True, out_dir='test_dir', create_symlink=True)
checkpoint_hook.before_run(runner)
assert checkpoint_hook.args['create_symlink']
runner.work_dir = 's3://path/of/file'
checkpoint_hook = CheckpointHook(
interval=1, by_epoch=True, create_symlink=True)
checkpoint_hook.before_run(runner)
assert not checkpoint_hook.args['create_symlink']
def test_after_train_epoch(self):
runner = Mock()
runner.work_dir = './tmp'
runner.epoch = 9
runner.meta = dict()
# by epoch is True
checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
checkpoint_hook.before_run(runner)
checkpoint_hook.after_train_epoch(runner)
assert (runner.epoch + 1) % 2 == 0
assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/epoch_10.pth'
# epoch can not be evenly divided by 2
runner.epoch = 10
checkpoint_hook.after_train_epoch(runner)
assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/epoch_10.pth'
# by epoch is False
runner.epoch = 9
runner.meta = dict()
checkpoint_hook = CheckpointHook(interval=2, by_epoch=False)
checkpoint_hook.before_run(runner)
checkpoint_hook.after_train_epoch(runner)
assert runner.meta.get('hook_msgs', None) is None
# max_keep_ckpts > 0
with TemporaryDirectory() as tempo_dir:
runner.work_dir = tempo_dir
os.system(f'touch {tempo_dir}/epoch_8.pth')
checkpoint_hook = CheckpointHook(
interval=2, by_epoch=True, max_keep_ckpts=1)
checkpoint_hook.before_run(runner)
checkpoint_hook.after_train_epoch(runner)
assert (runner.epoch + 1) % 2 == 0
assert not os.path.exists(f'{tempo_dir}/epoch_8.pth')
def test_after_train_iter(self):
runner = Mock()
runner.work_dir = './tmp'
runner.iter = 9
runner.meta = dict()
# by epoch is True
checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
checkpoint_hook.before_run(runner)
checkpoint_hook.after_train_iter(runner)
assert runner.meta.get('hook_msgs', None) is None
# by epoch is False
checkpoint_hook = CheckpointHook(interval=2, by_epoch=False)
checkpoint_hook.before_run(runner)
checkpoint_hook.after_train_iter(runner)
assert (runner.iter + 1) % 2 == 0
assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/iter_10.pth'
# epoch can not be evenly divided by 2
runner.iter = 10
checkpoint_hook.after_train_epoch(runner)
assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/iter_10.pth'
# max_keep_ckpts > 0
runner.iter = 9
with TemporaryDirectory() as tempo_dir:
runner.work_dir = tempo_dir
os.system(f'touch {tempo_dir}/iter_8.pth')
checkpoint_hook = CheckpointHook(
interval=2, by_epoch=False, max_keep_ckpts=1)
checkpoint_hook.before_run(runner)
checkpoint_hook.after_train_iter(runner)
assert not os.path.exists(f'{tempo_dir}/iter_8.pth')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment