From 50078256192fbe2f98a5643375a1d91468b4ba52 Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Thu, 5 May 2022 20:08:07 +0800
Subject: [PATCH] [Fix] change CheckPointHook before_run to before train (#214)

* change CheckPointHook before_run to before train

* using tmp_path in each checkpointhook test case
---
 mmengine/hooks/checkpoint_hook.py       |  8 +--
 tests/test_hook/test_checkpoint_hook.py | 89 +++++++++++++------------
 2 files changed, 49 insertions(+), 48 deletions(-)

diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py
index 017784b9..3ed332bd 100644
--- a/mmengine/hooks/checkpoint_hook.py
+++ b/mmengine/hooks/checkpoint_hook.py
@@ -69,7 +69,7 @@ class CheckpointHook(Hook):
         self.args = kwargs
         self.file_client_args = file_client_args
 
-    def before_run(self, runner) -> None:
+    def before_train(self, runner) -> None:
         """Finish all operations, related to checkpoint.
 
         This function will get the appropriate file client, and the directory
@@ -78,12 +78,11 @@ class CheckpointHook(Hook):
         Args:
             runner (Runner): The runner of the training process.
         """
-        if not self.out_dir:
+        if self.out_dir is None:
             self.out_dir = runner.work_dir
 
         self.file_client = FileClient.infer_client(self.file_client_args,
                                                    self.out_dir)
-
         # if `self.out_dir` is not equal to `runner.work_dir`, it means that
         # `self.out_dir` is set so the final `self.out_dir` is the
         # concatenation of `self.out_dir` and the last level directory of
@@ -186,8 +185,7 @@ class CheckpointHook(Hook):
             batch_idx (int): The index of the current batch in the train loop.
             data_batch (Sequence[dict], optional): Data from dataloader.
                 Defaults to None.
-            outputs (dict, optional): Outputs from model.
-                Defaults to None.
+            outputs (dict, optional): Outputs from model. Defaults to None.
         """
         if self.by_epoch:
             return
diff --git a/tests/test_hook/test_checkpoint_hook.py b/tests/test_hook/test_checkpoint_hook.py
index 7fabecd5..10f682cd 100644
--- a/tests/test_hook/test_checkpoint_hook.py
+++ b/tests/test_hook/test_checkpoint_hook.py
@@ -1,13 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import os
-import sys
-from tempfile import TemporaryDirectory
+import os.path as osp
 from unittest.mock import Mock, patch
 
 from mmengine.hooks import CheckpointHook
 
-sys.modules['file_client'] = sys.modules['mmengine.fileio.file_client']
-
 
 class MockPetrel:
 
@@ -30,75 +27,80 @@ prefix_to_backends = {'s3': MockPetrel}
 
 class TestCheckpointHook:
 
-    @patch('file_client.FileClient._prefix_to_backends', prefix_to_backends)
-    def test_before_run(self):
+    @patch('mmengine.fileio.file_client.FileClient._prefix_to_backends',
+           prefix_to_backends)
+    def test_before_train(self, tmp_path):
         runner = Mock()
-        runner.work_dir = './tmp'
+        work_dir = str(tmp_path)
+        runner.work_dir = work_dir
 
         # the out_dir of the checkpoint hook is None
         checkpoint_hook = CheckpointHook(interval=1, by_epoch=True)
-        checkpoint_hook.before_run(runner)
+        checkpoint_hook.before_train(runner)
         assert checkpoint_hook.out_dir == runner.work_dir
 
         # the out_dir of the checkpoint hook is not None
         checkpoint_hook = CheckpointHook(
             interval=1, by_epoch=True, out_dir='test_dir')
-        checkpoint_hook.before_run(runner)
-        assert checkpoint_hook.out_dir == 'test_dir/tmp'
+        checkpoint_hook.before_train(runner)
+        assert checkpoint_hook.out_dir == (
+            f'test_dir/{osp.basename(work_dir)}')
 
         # create_symlink in args and create_symlink is True
         checkpoint_hook = CheckpointHook(
             interval=1, by_epoch=True, out_dir='test_dir', create_symlink=True)
-        checkpoint_hook.before_run(runner)
+        checkpoint_hook.before_train(runner)
         assert checkpoint_hook.args['create_symlink']
 
         runner.work_dir = 's3://path/of/file'
         checkpoint_hook = CheckpointHook(
             interval=1, by_epoch=True, create_symlink=True)
-        checkpoint_hook.before_run(runner)
+        checkpoint_hook.before_train(runner)
         assert not checkpoint_hook.args['create_symlink']
 
-    def test_after_train_epoch(self):
+    def test_after_train_epoch(self, tmp_path):
         runner = Mock()
-        runner.work_dir = './tmp'
+        work_dir = str(tmp_path)
+        runner.work_dir = tmp_path
         runner.epoch = 9
         runner.meta = dict()
         runner.model = Mock()
 
         # by epoch is True
         checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
-        checkpoint_hook.before_run(runner)
+        checkpoint_hook.before_train(runner)
         checkpoint_hook.after_train_epoch(runner)
         assert (runner.epoch + 1) % 2 == 0
-        assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/epoch_10.pth'
-
+        assert runner.meta['hook_msgs']['last_ckpt'] == (
+            f'{work_dir}/epoch_10.pth')
         # epoch can not be evenly divided by 2
         runner.epoch = 10
         checkpoint_hook.after_train_epoch(runner)
-        assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/epoch_10.pth'
+        assert runner.meta['hook_msgs']['last_ckpt'] == (
+            f'{work_dir}/epoch_10.pth')
 
         # by epoch is False
         runner.epoch = 9
         runner.meta = dict()
         checkpoint_hook = CheckpointHook(interval=2, by_epoch=False)
-        checkpoint_hook.before_run(runner)
+        checkpoint_hook.before_train(runner)
         checkpoint_hook.after_train_epoch(runner)
         assert runner.meta.get('hook_msgs', None) is None
 
         # max_keep_ckpts > 0
-        with TemporaryDirectory() as tempo_dir:
-            runner.work_dir = tempo_dir
-            os.system(f'touch {tempo_dir}/epoch_8.pth')
-            checkpoint_hook = CheckpointHook(
-                interval=2, by_epoch=True, max_keep_ckpts=1)
-            checkpoint_hook.before_run(runner)
-            checkpoint_hook.after_train_epoch(runner)
-            assert (runner.epoch + 1) % 2 == 0
-            assert not os.path.exists(f'{tempo_dir}/epoch_8.pth')
-
-    def test_after_train_iter(self):
+        runner.work_dir = work_dir
+        os.system(f'touch {work_dir}/epoch_8.pth')
+        checkpoint_hook = CheckpointHook(
+            interval=2, by_epoch=True, max_keep_ckpts=1)
+        checkpoint_hook.before_train(runner)
+        checkpoint_hook.after_train_epoch(runner)
+        assert (runner.epoch + 1) % 2 == 0
+        assert not os.path.exists(f'{work_dir}/epoch_8.pth')
+
+    def test_after_train_iter(self, tmp_path):
+        work_dir = str(tmp_path)
         runner = Mock()
-        runner.work_dir = './tmp'
+        runner.work_dir = str(work_dir)
         runner.iter = 9
         batch_idx = 9
         runner.meta = dict()
@@ -106,29 +108,30 @@ class TestCheckpointHook:
 
         # by epoch is True
         checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
-        checkpoint_hook.before_run(runner)
+        checkpoint_hook.before_train(runner)
         checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
         assert runner.meta.get('hook_msgs', None) is None
 
         # by epoch is False
         checkpoint_hook = CheckpointHook(interval=2, by_epoch=False)
-        checkpoint_hook.before_run(runner)
+        checkpoint_hook.before_train(runner)
         checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
         assert (runner.iter + 1) % 2 == 0
-        assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/iter_10.pth'
+        assert runner.meta['hook_msgs']['last_ckpt'] == (
+            f'{work_dir}/iter_10.pth')
 
         # epoch can not be evenly divided by 2
         runner.iter = 10
         checkpoint_hook.after_train_epoch(runner)
-        assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/iter_10.pth'
+        assert runner.meta['hook_msgs']['last_ckpt'] == (
+            f'{work_dir}/iter_10.pth')
 
         # max_keep_ckpts > 0
         runner.iter = 9
-        with TemporaryDirectory() as tempo_dir:
-            runner.work_dir = tempo_dir
-            os.system(f'touch {tempo_dir}/iter_8.pth')
-            checkpoint_hook = CheckpointHook(
-                interval=2, by_epoch=False, max_keep_ckpts=1)
-            checkpoint_hook.before_run(runner)
-            checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
-            assert not os.path.exists(f'{tempo_dir}/iter_8.pth')
+        runner.work_dir = work_dir
+        os.system(f'touch {work_dir}/iter_8.pth')
+        checkpoint_hook = CheckpointHook(
+            interval=2, by_epoch=False, max_keep_ckpts=1)
+        checkpoint_hook.before_train(runner)
+        checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
+        assert not os.path.exists(f'{work_dir}/iter_8.pth')
-- 
GitLab