From 72cf410969ae486960aad348f620d395a306466d Mon Sep 17 00:00:00 2001
From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Date: Sun, 13 Mar 2022 23:39:28 +0800
Subject: [PATCH] [Refactor] Refactor interface of CheckpointHook (#127)

* [Refactor] Refactor interface of CheckpointHook

* fix print format

* minor fix
---
 mmengine/evaluator/base.py        |  2 +-
 mmengine/hooks/checkpoint_hook.py | 44 ++++++++++++++++++++-----------
 2 files changed, 30 insertions(+), 16 deletions(-)

diff --git a/mmengine/evaluator/base.py b/mmengine/evaluator/base.py
index bc90634e..5354df90 100644
--- a/mmengine/evaluator/base.py
+++ b/mmengine/evaluator/base.py
@@ -16,7 +16,7 @@ class BaseEvaluator(metaclass=ABCMeta):
     Then it collects all results together from all ranks if distributed
     training is used. Finally, it computes the metrics of the entire dataset.
 
-    A subclass of class:`BaseEvaluator` should assign a meanful value to the
+    A subclass of class:`BaseEvaluator` should assign a meaningful value to the
     class attribute `default_prefix`. See the argument `prefix` for details.
 
     Args:
diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py
index 313900b7..71929c0c 100644
--- a/mmengine/hooks/checkpoint_hook.py
+++ b/mmengine/hooks/checkpoint_hook.py
@@ -25,6 +25,9 @@ class CheckpointHook(Hook):
         save_optimizer (bool): Whether to save optimizer state_dict in the
             checkpoint. It is usually used for resuming experiments.
             Default: True.
+        save_param_scheduler (bool): Whether to save param_scheduler state_dict
+            in the checkpoint. It is usually used for resuming experiments.
+            Default: True.
         out_dir (str, optional | Path): The root directory to save checkpoints.
             If not specified, ``runner.work_dir`` will be used by default. If
             specified, the ``out_dir`` will be the concatenation of ``out_dir``
@@ -44,6 +47,7 @@ class CheckpointHook(Hook):
             FileClient. See :class:`mmcv.fileio.FileClient` for details.
             Default: None.
     """
+    out_dir: str
     priority = 'VERY_LOW'
@@ -51,7 +55,8 @@ class CheckpointHook(Hook):
                  interval: int = -1,
                  by_epoch: bool = True,
                  save_optimizer: bool = True,
-                 out_dir: Union[str, Path] = None,
+                 save_param_scheduler: bool = True,
+                 out_dir: Optional[Union[str, Path]] = None,
                  max_keep_ckpts: int = -1,
                  save_last: bool = True,
                  sync_buffer: bool = False,
@@ -60,7 +65,8 @@ class CheckpointHook(Hook):
         self.interval = interval
         self.by_epoch = by_epoch
         self.save_optimizer = save_optimizer
-        self.out_dir = out_dir
+        self.save_param_scheduler = save_param_scheduler
+        self.out_dir = out_dir  # type: ignore
         self.max_keep_ckpts = max_keep_ckpts
         self.save_last = save_last
         self.args = kwargs
@@ -121,8 +127,8 @@ class CheckpointHook(Hook):
         # 2. reach the last epoch of training
         if self.every_n_epochs(runner, self.interval) or (
                 self.save_last and self.is_last_train_epoch(runner)):
-            runner.logger.info(f'Saving checkpoint at \
-                {runner.epoch + 1} epochs')
+            runner.logger.info(
+                f'Saving checkpoint at {runner.epoch + 1} epochs')
             if self.sync_buffer:
                 pass
                 # TODO
@@ -135,18 +141,26 @@ class CheckpointHook(Hook):
         Args:
             runner (Runner): The runner of the training process.
""" + if self.by_epoch: + ckpt_filename = self.args.get( + 'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1) + else: + ckpt_filename = self.args.get( + 'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1) + runner.save_checkpoint( - self.out_dir, save_optimizer=self.save_optimizer, **self.args) + self.out_dir, + filename=ckpt_filename, + save_optimizer=self.save_optimizer, + save_param_scheduler=self.save_param_scheduler, + by_epoch=self.by_epoch, + **self.args) + if runner.meta is not None: - if self.by_epoch: - cur_ckpt_filename = self.args.get( - 'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1) - else: - cur_ckpt_filename = self.args.get( - 'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1) runner.meta.setdefault('hook_msgs', dict()) runner.meta['hook_msgs']['last_ckpt'] = self.file_client.join_path( - self.out_dir, cur_ckpt_filename) # type: ignore + self.out_dir, ckpt_filename) + # remove other checkpoints if self.max_keep_ckpts > 0: if self.by_epoch: @@ -161,7 +175,7 @@ class CheckpointHook(Hook): filename_tmpl = self.args.get('filename_tmpl', name) for _step in redundant_ckpts: ckpt_path = self.file_client.join_path( - self.out_dir, filename_tmpl.format(_step)) # type: ignore + self.out_dir, filename_tmpl.format(_step)) if self.file_client.isfile(ckpt_path): self.file_client.remove(ckpt_path) else: @@ -188,8 +202,8 @@ class CheckpointHook(Hook): # 2. reach the last iteration of training if self.every_n_iters(runner, self.interval) or \ (self.save_last and self.is_last_iter(runner, mode='train')): - runner.logger.info(f'Saving checkpoint at \ - {runner.iter + 1} iterations') + runner.logger.info( + f'Saving checkpoint at {runner.iter + 1} iterations') if self.sync_buffer: pass # TODO -- GitLab