From cf239a2b17a3bc107ba385bae71453482858d0e8 Mon Sep 17 00:00:00 2001
From: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com>
Date: Wed, 2 Mar 2022 22:01:58 +0800
Subject: [PATCH] [Feature]: Add checkpoint hook (#66)

* [Feature]: Add checkpoint hook

* [Fix]: Fix lint

* [Fix]: Delete redundant optional and give an example for out_dir

* [Feature]: Add test the last_ckpt in UT

* [Fix]: Fix docstring problem

* [Fix]: Add patch to UT

* [Feature]: Add Test case for by epoch
---
 mmengine/hooks/__init__.py              |   3 +-
 mmengine/hooks/checkpoint_hook.py       | 206 ++++++++++++++++++++++++
 tests/test_hook/test_checkpoint_hook.py | 131 +++++++++++++++
 3 files changed, 339 insertions(+), 1 deletion(-)
 create mode 100644 mmengine/hooks/checkpoint_hook.py
 create mode 100644 tests/test_hook/test_checkpoint_hook.py

diff --git a/mmengine/hooks/__init__.py b/mmengine/hooks/__init__.py
index 4bb2b676..a91f093b 100644
--- a/mmengine/hooks/__init__.py
+++ b/mmengine/hooks/__init__.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .empty_cache_hook import EmptyCacheHook
+from .checkpoint_hook import CheckpointHook
 from .hook import Hook
 from .iter_timer_hook import IterTimerHook
 from .optimizer_hook import OptimizerHook
@@ -8,5 +9,5 @@ from .sampler_seed_hook import DistSamplerSeedHook
 
 __all__ = [
     'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
-    'OptimizerHook', 'EmptyCacheHook'
+    'OptimizerHook', 'EmptyCacheHook', 'CheckpointHook'
 ]
diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py
new file mode 100644
index 00000000..7baa99ed
--- /dev/null
+++ b/mmengine/hooks/checkpoint_hook.py
@@ -0,0 +1,206 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import warnings
+from pathlib import Path
+from typing import Optional, Sequence, Union
+
+from mmengine.data import BaseDataSample
+from mmengine.fileio import FileClient
+from mmengine.registry import HOOKS
+from .hook import Hook
+
+
+@HOOKS.register_module()
+class CheckpointHook(Hook):
+    """Save checkpoints periodically.
+
+    Args:
+        interval (int): The saving period. If ``by_epoch=True``, interval
+            indicates epochs, otherwise it indicates iterations.
+            Default: -1, which means "never".
+        by_epoch (bool): Saving checkpoints by epoch or by iteration.
+            Default: True.
+        save_optimizer (bool): Whether to save optimizer state_dict in the
+            checkpoint. It is usually used for resuming experiments.
+            Default: True.
+        out_dir (str or Path, optional): The root directory to save
+            checkpoints. If not specified, ``runner.work_dir`` will be used by
+            default. If specified, the ``out_dir`` will be the concatenation of
+            ``out_dir`` and the last level directory of ``runner.work_dir``.
+            For example, if the input ``out_dir`` is ``./tmp`` and
+            ``runner.work_dir`` is ``./work_dir/cur_exp``, then the checkpoint
+            will be saved in ``./tmp/cur_exp``. Defaults to None.
+        max_keep_ckpts (int): The maximum checkpoints to keep.
+            In some cases we want only the latest few checkpoints and would
+            like to delete old ones to save the disk space.
+            Default: -1, which means unlimited.
+        save_last (bool): Whether to force the last checkpoint to be
+            saved regardless of interval. Default: True.
+        sync_buffer (bool): Whether to synchronize buffers in
+            different gpus. Default: False.
+        file_client_args (dict, optional): Arguments to instantiate a
+            FileClient. See :class:`mmengine.fileio.FileClient` for details.
+            Default: None.
+    """
+
+    def __init__(self,
+                 interval: int = -1,
+                 by_epoch: bool = True,
+                 save_optimizer: bool = True,
+                 out_dir: Optional[Union[str, Path]] = None,
+                 max_keep_ckpts: int = -1,
+                 save_last: bool = True,
+                 sync_buffer: bool = False,
+                 file_client_args: Optional[dict] = None,
+                 **kwargs) -> None:
+        self.interval = interval
+        self.by_epoch = by_epoch
+        self.save_optimizer = save_optimizer
+        self.out_dir = out_dir  # finalized in ``before_run``
+        self.max_keep_ckpts = max_keep_ckpts
+        self.save_last = save_last
+        self.args = kwargs  # forwarded to ``runner.save_checkpoint``
+        self.sync_buffer = sync_buffer
+        self.file_client_args = file_client_args
+
+    def before_run(self, runner: object) -> None:
+        """Prepare the file client and the checkpoint output directory.
+
+        This function infers the file client from ``out_dir`` and resolves
+        the directory in which the checkpoints of the model are saved.
+
+        Args:
+            runner (object): The runner of the training process.
+        """
+        if not self.out_dir:
+            self.out_dir = runner.work_dir  # type: ignore
+
+        self.file_client = FileClient.infer_client(self.file_client_args,
+                                                   self.out_dir)
+
+        # if `self.out_dir` is not equal to `runner.work_dir`, it means that
+        # `self.out_dir` is set so the final `self.out_dir` is the
+        # concatenation of `self.out_dir` and the last level directory of
+        # `runner.work_dir`
+        if self.out_dir != runner.work_dir:  # type: ignore
+            basename = osp.basename(
+                runner.work_dir.rstrip(  # type: ignore
+                    osp.sep))
+            self.out_dir = self.file_client.join_path(
+                self.out_dir,  # type: ignore
+                basename)
+
+        runner.logger.info((  # type: ignore
+            f'Checkpoints will be saved to {self.out_dir} by '
+            f'{self.file_client.name}.'))
+
+        # disable the create_symlink option because some file backends do not
+        # support creating symlinks
+        if 'create_symlink' in self.args:
+            if self.args[
+                    'create_symlink'] and not self.file_client.allow_symlink:
+                self.args['create_symlink'] = False
+                warnings.warn(
+                    ('create_symlink is set as True by the user but is '
+                     'changed to be False because creating symbolic link is '
+                     f'not allowed in {self.file_client.name}'))
+        else:
+            self.args['create_symlink'] = self.file_client.allow_symlink
+
+    def after_train_epoch(self, runner: object) -> None:
+        """Save the checkpoint after a training epoch if it is due.
+
+        Args:
+            runner (object): The runner of the training process.
+        """
+        if not self.by_epoch:
+            return
+
+        # save checkpoint for following cases:
+        # 1. every ``self.interval`` epochs
+        # 2. reach the last epoch of training
+        if self.every_n_epochs(
+                runner, self.interval) or (self.save_last
+                                           and self.is_last_epoch(runner)):
+            runner.logger.info(  # type: ignore
+                'Saving checkpoint at '
+                f'{runner.epoch + 1} epochs')  # type: ignore
+            if self.sync_buffer:
+                pass
+                # TODO: buffer synchronization is not implemented yet
+            self._save_checkpoint(runner)
+
+    # TODO Add master_only decorator
+    def _save_checkpoint(self, runner: object) -> None:
+        """Save the current checkpoint and delete outdated checkpoints.
+
+        Args:
+            runner (object): The runner of the training process.
+        """
+        runner.save_checkpoint(  # type: ignore
+            self.out_dir,
+            save_optimizer=self.save_optimizer,
+            **self.args)
+        if runner.meta is not None:  # type: ignore
+            if self.by_epoch:
+                cur_ckpt_filename = self.args.get(
+                    'filename_tmpl',
+                    'epoch_{}.pth').format(runner.epoch + 1)  # type: ignore
+            else:
+                cur_ckpt_filename = self.args.get(
+                    'filename_tmpl',
+                    'iter_{}.pth').format(runner.iter + 1)  # type: ignore
+            runner.meta.setdefault('hook_msgs', dict())  # type: ignore
+            runner.meta['hook_msgs'][  # type: ignore
+                'last_ckpt'] = self.file_client.join_path(
+                    self.out_dir, cur_ckpt_filename)  # type: ignore
+        # remove old checkpoints; a zero ``interval`` would break ``range``
+        if self.max_keep_ckpts > 0 and self.interval > 0:
+            if self.by_epoch:
+                name = 'epoch_{}.pth'
+                current_ckpt = runner.epoch + 1  # type: ignore
+            else:
+                name = 'iter_{}.pth'
+                current_ckpt = runner.iter + 1  # type: ignore
+            redundant_ckpts = range(
+                current_ckpt - self.max_keep_ckpts * self.interval, 0,
+                -self.interval)
+            filename_tmpl = self.args.get('filename_tmpl', name)
+            for _step in redundant_ckpts:
+                ckpt_path = self.file_client.join_path(
+                    self.out_dir, filename_tmpl.format(_step))  # type: ignore
+                if self.file_client.isfile(ckpt_path):
+                    self.file_client.remove(ckpt_path)
+                else:  # stop at the first missing checkpoint
+                    break
+
+    def after_train_iter(
+            self,
+            runner: object,
+            data_batch: Optional[Sequence[BaseDataSample]] = None,
+            outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
+        """Save the checkpoint after a training iteration if it is due.
+
+        Args:
+            runner (object): The runner of the training process.
+            data_batch (Sequence[BaseDataSample]): Data from dataloader.
+                Defaults to None.
+            outputs (Sequence[BaseDataSample], optional): Outputs from model.
+                Defaults to None.
+        """
+        if self.by_epoch:
+            return
+
+        # save checkpoint for following cases:
+        # 1. every ``self.interval`` iterations
+        # 2. reach the last iteration of training
+        if self.every_n_iters(
+                runner, self.interval) or (self.save_last
+                                           and self.is_last_iter(runner)):
+            runner.logger.info(  # type: ignore
+                'Saving checkpoint at '
+                f'{runner.iter + 1} iterations')  # type: ignore
+            if self.sync_buffer:
+                pass
+                # TODO: buffer synchronization is not implemented yet
+            self._save_checkpoint(runner)
diff --git a/tests/test_hook/test_checkpoint_hook.py b/tests/test_hook/test_checkpoint_hook.py
new file mode 100644
index 00000000..f45b433a
--- /dev/null
+++ b/tests/test_hook/test_checkpoint_hook.py
@@ -0,0 +1,131 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import sys
+from tempfile import TemporaryDirectory
+from unittest.mock import Mock, patch
+
+from mmengine.hooks import CheckpointHook
+
+sys.modules['file_client'] = sys.modules['mmengine.fileio.file_client']
+
+
+class MockPetrel:
+    """Fake file backend mapped to the ``s3`` prefix; disallows symlinks."""
+    _allow_symlink = False
+
+    def __init__(self):
+        pass
+
+    @property
+    def name(self):
+        return self.__class__.__name__
+
+    @property
+    def allow_symlink(self):
+        return self._allow_symlink
+
+
+prefix_to_backends = {'s3': MockPetrel}
+
+
+class TestCheckpointHook:
+
+    @patch('file_client.FileClient._prefix_to_backends', prefix_to_backends)
+    def test_before_run(self):
+        runner = Mock()
+        runner.work_dir = './tmp'
+
+        # out_dir is None: runner.work_dir is used as the checkpoint directory
+        checkpoint_hook = CheckpointHook(interval=1, by_epoch=True)
+        checkpoint_hook.before_run(runner)
+        assert checkpoint_hook.out_dir == runner.work_dir
+
+        # out_dir is given: the basename of runner.work_dir is appended to it
+        checkpoint_hook = CheckpointHook(
+            interval=1, by_epoch=True, out_dir='test_dir')
+        checkpoint_hook.before_run(runner)
+        assert checkpoint_hook.out_dir == 'test_dir/tmp'
+
+        # create_symlink=True is kept for a local out_dir
+        checkpoint_hook = CheckpointHook(
+            interval=1, by_epoch=True, out_dir='test_dir', create_symlink=True)
+        checkpoint_hook.before_run(runner)
+        assert checkpoint_hook.args['create_symlink']
+
+        runner.work_dir = 's3://path/of/file'  # ``s3`` maps to MockPetrel
+        checkpoint_hook = CheckpointHook(
+            interval=1, by_epoch=True, create_symlink=True)
+        checkpoint_hook.before_run(runner)
+        assert not checkpoint_hook.args['create_symlink']
+
+    def test_after_train_epoch(self):
+        runner = Mock()
+        runner.work_dir = './tmp'
+        runner.epoch = 9
+        runner.meta = dict()
+
+        # by_epoch=True: epoch 10 is a multiple of the interval and is saved
+        checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
+        checkpoint_hook.before_run(runner)
+        checkpoint_hook.after_train_epoch(runner)
+        assert (runner.epoch + 1) % 2 == 0
+        assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/epoch_10.pth'
+
+        # epoch 11 is not a multiple of 2: last_ckpt must stay unchanged
+        runner.epoch = 10
+        checkpoint_hook.after_train_epoch(runner)
+        assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/epoch_10.pth'
+
+        # by_epoch=False: after_train_epoch must not save anything
+        runner.epoch = 9
+        runner.meta = dict()
+        checkpoint_hook = CheckpointHook(interval=2, by_epoch=False)
+        checkpoint_hook.before_run(runner)
+        checkpoint_hook.after_train_epoch(runner)
+        assert runner.meta.get('hook_msgs', None) is None
+
+        # max_keep_ckpts > 0: the outdated epoch_8 checkpoint is removed
+        with TemporaryDirectory() as tempo_dir:
+            runner.work_dir = tempo_dir
+            open(os.path.join(tempo_dir, 'epoch_8.pth'), 'w').close()
+            checkpoint_hook = CheckpointHook(
+                interval=2, by_epoch=True, max_keep_ckpts=1)
+            checkpoint_hook.before_run(runner)
+            checkpoint_hook.after_train_epoch(runner)
+            assert (runner.epoch + 1) % 2 == 0
+            assert not os.path.exists(f'{tempo_dir}/epoch_8.pth')
+
+    def test_after_train_iter(self):
+        runner = Mock()
+        runner.work_dir = './tmp'
+        runner.iter = 9
+        runner.meta = dict()
+
+        # by_epoch=True: after_train_iter must not save anything
+        checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
+        checkpoint_hook.before_run(runner)
+        checkpoint_hook.after_train_iter(runner)
+        assert runner.meta.get('hook_msgs', None) is None
+
+        # by_epoch=False: iteration 10 is a multiple of the interval
+        checkpoint_hook = CheckpointHook(interval=2, by_epoch=False)
+        checkpoint_hook.before_run(runner)
+        checkpoint_hook.after_train_iter(runner)
+        assert (runner.iter + 1) % 2 == 0
+        assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/iter_10.pth'
+
+        # iteration 11 is not a multiple of 2: last_ckpt must stay unchanged
+        runner.iter = 10
+        checkpoint_hook.after_train_iter(runner)
+        assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/iter_10.pth'
+
+        # max_keep_ckpts > 0: the outdated iter_8 checkpoint is removed
+        runner.iter = 9
+        with TemporaryDirectory() as tempo_dir:
+            runner.work_dir = tempo_dir
+            open(os.path.join(tempo_dir, 'iter_8.pth'), 'w').close()
+            checkpoint_hook = CheckpointHook(
+                interval=2, by_epoch=False, max_keep_ckpts=1)
+            checkpoint_hook.before_run(runner)
+            checkpoint_hook.after_train_iter(runner)
+            assert not os.path.exists(f'{tempo_dir}/iter_8.pth')
-- 
GitLab