Unverified commit 72cf4109, authored by Zaida Zhou and committed by GitHub
[Refactor] Refactor interface of CheckpointHook (#127)

* [Refactor] Refactor interface of CheckpointHook

* fix print format

* minor fix
parent fff4742e
@@ -16,7 +16,7 @@ class BaseEvaluator(metaclass=ABCMeta):
     Then it collects all results together from all ranks if distributed
     training is used. Finally, it computes the metrics of the entire dataset.

-    A subclass of class:`BaseEvaluator` should assign a meanful value to the
+    A subclass of class:`BaseEvaluator` should assign a meaningful value to the
     class attribute `default_prefix`. See the argument `prefix` for details.

     Args:
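Note: the docstring above asks each subclass to set the class attribute `default_prefix`. A minimal sketch of such a subclass (the `Accuracy` name and the 'ACC' prefix are illustrative only, the import path is assumed, and the abstract result-processing methods are omitted):

from mmengine.evaluator import BaseEvaluator  # import path assumed

class Accuracy(BaseEvaluator):
    # Metrics from this evaluator are reported as 'ACC/<metric_name>'
    # unless the caller overrides it via the `prefix` argument.
    default_prefix = 'ACC'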
@@ -25,6 +25,9 @@ class CheckpointHook(Hook):
         save_optimizer (bool): Whether to save optimizer state_dict in the
             checkpoint. It is usually used for resuming experiments.
             Default: True.
+        save_param_scheduler (bool): Whether to save param_scheduler state_dict
+            in the checkpoint. It is usually used for resuming experiments.
+            Default: True.
         out_dir (str, optional | Path): The root directory to save checkpoints.
             If not specified, ``runner.work_dir`` will be used by default. If
             specified, the ``out_dir`` will be the concatenation of ``out_dir``
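Note: the `out_dir` paragraph is cut off by the hunk, but the same docstring in mmcv's CheckpointHook continues with "and the last level directory of `runner.work_dir`". A sketch of that resolution rule, under the assumption that mmengine keeps the mmcv behavior:

import os.path as osp

def resolve_out_dir(out_dir, work_dir):
    # No out_dir given: save directly under the runner's work_dir.
    if out_dir is None:
        return work_dir
    # Otherwise append the last directory level of work_dir, so a run
    # with work_dir='/path/of/A' and out_dir='/tmp/B' saves to '/tmp/B/A'.
    return osp.join(out_dir, osp.basename(work_dir.rstrip(osp.sep)))

assert resolve_out_dir(None, '/path/of/A') == '/path/of/A'
assert resolve_out_dir('/tmp/B', '/path/of/A') == '/tmp/B/A'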
@@ -44,6 +47,7 @@ class CheckpointHook(Hook):
             FileClient. See :class:`mmcv.fileio.FileClient` for details.
             Default: None.
     """
+    out_dir: str

     priority = 'VERY_LOW'
@@ -51,7 +55,8 @@ class CheckpointHook(Hook):
                  interval: int = -1,
                  by_epoch: bool = True,
                  save_optimizer: bool = True,
-                 out_dir: Union[str, Path] = None,
+                 save_param_scheduler: bool = True,
+                 out_dir: Optional[Union[str, Path]] = None,
                  max_keep_ckpts: int = -1,
                  save_last: bool = True,
                  sync_buffer: bool = False,
@@ -60,7 +65,8 @@ class CheckpointHook(Hook):
         self.interval = interval
         self.by_epoch = by_epoch
         self.save_optimizer = save_optimizer
-        self.out_dir = out_dir
+        self.save_param_scheduler = save_param_scheduler
+        self.out_dir = out_dir  # type: ignore
         self.max_keep_ckpts = max_keep_ckpts
         self.save_last = save_last
         self.args = kwargs
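Note: the new class-level `out_dir: str` annotation (previous hunk) together with the `# type: ignore` above is a common mypy idiom: declare the attribute with the type it has after setup, and silence the one constructor assignment that may still hold None. A self-contained sketch of the same pattern, with a hypothetical `Widget` class:

from pathlib import Path
from typing import Optional, Union

class Widget:
    # Declared with the post-setup type, so every later read of
    # self.out_dir type-checks as a plain `str`.
    out_dir: str

    def __init__(self, out_dir: Optional[Union[str, Path]] = None):
        self.out_dir = out_dir  # type: ignore  # may be None until setup

    def setup(self, default_dir: str) -> None:
        # Resolve the Optional exactly once, before first real use.
        self.out_dir = str(self.out_dir) if self.out_dir else default_dir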
@@ -121,8 +127,8 @@ class CheckpointHook(Hook):
         # 2. reach the last epoch of training
         if self.every_n_epochs(runner, self.interval) or (
                 self.save_last and self.is_last_train_epoch(runner)):
-            runner.logger.info(f'Saving checkpoint at \
-{runner.epoch + 1} epochs')
+            runner.logger.info(
+                f'Saving checkpoint at {runner.epoch + 1} epochs')
             if self.sync_buffer:
                 pass
                 # TODO
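Note on the "fix print format" change: a backslash continuation inside an f-string keeps the continuation line's leading whitespace in the message, so the old form only printed cleanly when the second line hugged the left margin. A quick demonstration:

epoch = 11
# Indenting the continuation leaks the indentation into the output:
msg = f'Saving checkpoint at \
    {epoch + 1} epochs'
assert msg == 'Saving checkpoint at     12 epochs'  # 4 stray spaces
# The refactored form has no such trap:
msg = f'Saving checkpoint at {epoch + 1} epochs'
assert msg == 'Saving checkpoint at 12 epochs'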
@@ -135,18 +141,26 @@ class CheckpointHook(Hook):
         Args:
             runner (Runner): The runner of the training process.
         """
+        if self.by_epoch:
+            ckpt_filename = self.args.get(
+                'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1)
+        else:
+            ckpt_filename = self.args.get(
+                'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
         runner.save_checkpoint(
-            self.out_dir, save_optimizer=self.save_optimizer, **self.args)
+            self.out_dir,
+            filename=ckpt_filename,
+            save_optimizer=self.save_optimizer,
+            save_param_scheduler=self.save_param_scheduler,
+            by_epoch=self.by_epoch,
+            **self.args)
         if runner.meta is not None:
-            if self.by_epoch:
-                cur_ckpt_filename = self.args.get(
-                    'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1)
-            else:
-                cur_ckpt_filename = self.args.get(
-                    'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
             runner.meta.setdefault('hook_msgs', dict())
             runner.meta['hook_msgs']['last_ckpt'] = self.file_client.join_path(
-                self.out_dir, cur_ckpt_filename)  # type: ignore
+                self.out_dir, ckpt_filename)
         # remove other checkpoints
         if self.max_keep_ckpts > 0:
             if self.by_epoch:
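Note: the `filename_tmpl` lookup now runs once, before saving, and the result is reused for the `last_ckpt` meta entry. A standalone sketch of that lookup (the `args` dict stands in for the hook's `self.args` kwargs):

def checkpoint_name(args, by_epoch, epoch, iteration):
    # A user-supplied 'filename_tmpl' wins; otherwise fall back to a
    # mode-specific default. Counters are 0-based, filenames 1-based.
    if by_epoch:
        return args.get('filename_tmpl', 'epoch_{}.pth').format(epoch + 1)
    return args.get('filename_tmpl', 'iter_{}.pth').format(iteration + 1)

assert checkpoint_name({}, True, 3, 999) == 'epoch_4.pth'
assert checkpoint_name({}, False, 3, 999) == 'iter_1000.pth'
assert checkpoint_name({'filename_tmpl': 'ckpt_{}.pth'}, True, 3, 999) == 'ckpt_4.pth'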
@@ -161,7 +175,7 @@ class CheckpointHook(Hook):
             filename_tmpl = self.args.get('filename_tmpl', name)
             for _step in redundant_ckpts:
                 ckpt_path = self.file_client.join_path(
-                    self.out_dir, filename_tmpl.format(_step))  # type: ignore
+                    self.out_dir, filename_tmpl.format(_step))
                 if self.file_client.isfile(ckpt_path):
                     self.file_client.remove(ckpt_path)
                 else:
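Note: `redundant_ckpts` is computed just above this hunk, outside the diff. In mmcv's CheckpointHook the equivalent cleanup walks backwards from the oldest checkpoint that should survive; a sketch under the assumption that mmengine keeps that logic:

def redundant_steps(current, interval, max_keep):
    # Steps older than the newest `max_keep` checkpoints, assuming one
    # checkpoint was written every `interval` steps.
    return list(range(current - max_keep * interval, 0, -interval))

# Keeping the 2 newest of epoch_2..epoch_10 deletes epochs 6, 4 and 2.
assert redundant_steps(current=10, interval=2, max_keep=2) == [6, 4, 2]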
@@ -188,8 +202,8 @@ class CheckpointHook(Hook):
         # 2. reach the last iteration of training
         if self.every_n_iters(runner, self.interval) or \
                 (self.save_last and self.is_last_iter(runner, mode='train')):
-            runner.logger.info(f'Saving checkpoint at \
-{runner.iter + 1} iterations')
+            runner.logger.info(
+                f'Saving checkpoint at {runner.iter + 1} iterations')
             if self.sync_buffer:
                 pass
                 # TODO
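Note: both the epoch and iteration branches gate on the base Hook's `every_n_epochs` / `every_n_iters` helpers. Their semantics, as in the mmcv `Hook` class these appear to mirror (1-based check, disabled for a non-positive n):

def every_n_iters(cur_iter, n):
    # cur_iter is 0-based; fires on iterations n, 2n, 3n, ...
    # A non-positive n (e.g. interval=-1) disables interval-based saving.
    return (cur_iter + 1) % n == 0 if n > 0 else False

assert every_n_iters(99, 100) is True    # the 100th iteration
assert every_n_iters(100, 100) is False
assert every_n_iters(99, -1) is False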