Skip to content
Snippets Groups Projects
Unverified commit 50078256, authored by Mashiro and committed by GitHub
Browse files

[Fix] Change CheckpointHook `before_run` to `before_train` (#214)

* Change CheckpointHook `before_run` to `before_train`

* Use `tmp_path` in each CheckpointHook test case
parent a1adbff1
No related branches found
No related tags found
No related merge requests found
...@@ -69,7 +69,7 @@ class CheckpointHook(Hook): ...@@ -69,7 +69,7 @@ class CheckpointHook(Hook):
self.args = kwargs self.args = kwargs
self.file_client_args = file_client_args self.file_client_args = file_client_args
def before_run(self, runner) -> None: def before_train(self, runner) -> None:
"""Finish all operations, related to checkpoint. """Finish all operations, related to checkpoint.
This function will get the appropriate file client, and the directory This function will get the appropriate file client, and the directory
...@@ -78,12 +78,11 @@ class CheckpointHook(Hook): ...@@ -78,12 +78,11 @@ class CheckpointHook(Hook):
Args: Args:
runner (Runner): The runner of the training process. runner (Runner): The runner of the training process.
""" """
if not self.out_dir: if self.out_dir is None:
self.out_dir = runner.work_dir self.out_dir = runner.work_dir
self.file_client = FileClient.infer_client(self.file_client_args, self.file_client = FileClient.infer_client(self.file_client_args,
self.out_dir) self.out_dir)
# if `self.out_dir` is not equal to `runner.work_dir`, it means that # if `self.out_dir` is not equal to `runner.work_dir`, it means that
# `self.out_dir` is set so the final `self.out_dir` is the # `self.out_dir` is set so the final `self.out_dir` is the
# concatenation of `self.out_dir` and the last level directory of # concatenation of `self.out_dir` and the last level directory of
...@@ -186,8 +185,7 @@ class CheckpointHook(Hook): ...@@ -186,8 +185,7 @@ class CheckpointHook(Hook):
batch_idx (int): The index of the current batch in the train loop. batch_idx (int): The index of the current batch in the train loop.
data_batch (Sequence[dict], optional): Data from dataloader. data_batch (Sequence[dict], optional): Data from dataloader.
Defaults to None. Defaults to None.
outputs (dict, optional): Outputs from model. outputs (dict, optional): Outputs from model. Defaults to None.
Defaults to None.
""" """
if self.by_epoch: if self.by_epoch:
return return
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import os import os
import sys import os.path as osp
from tempfile import TemporaryDirectory
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from mmengine.hooks import CheckpointHook from mmengine.hooks import CheckpointHook
sys.modules['file_client'] = sys.modules['mmengine.fileio.file_client']
class MockPetrel: class MockPetrel:
...@@ -30,75 +27,80 @@ prefix_to_backends = {'s3': MockPetrel} ...@@ -30,75 +27,80 @@ prefix_to_backends = {'s3': MockPetrel}
class TestCheckpointHook: class TestCheckpointHook:
@patch('file_client.FileClient._prefix_to_backends', prefix_to_backends) @patch('mmengine.fileio.file_client.FileClient._prefix_to_backends',
def test_before_run(self): prefix_to_backends)
def test_before_train(self, tmp_path):
runner = Mock() runner = Mock()
runner.work_dir = './tmp' work_dir = str(tmp_path)
runner.work_dir = work_dir
# the out_dir of the checkpoint hook is None # the out_dir of the checkpoint hook is None
checkpoint_hook = CheckpointHook(interval=1, by_epoch=True) checkpoint_hook = CheckpointHook(interval=1, by_epoch=True)
checkpoint_hook.before_run(runner) checkpoint_hook.before_train(runner)
assert checkpoint_hook.out_dir == runner.work_dir assert checkpoint_hook.out_dir == runner.work_dir
# the out_dir of the checkpoint hook is not None # the out_dir of the checkpoint hook is not None
checkpoint_hook = CheckpointHook( checkpoint_hook = CheckpointHook(
interval=1, by_epoch=True, out_dir='test_dir') interval=1, by_epoch=True, out_dir='test_dir')
checkpoint_hook.before_run(runner) checkpoint_hook.before_train(runner)
assert checkpoint_hook.out_dir == 'test_dir/tmp' assert checkpoint_hook.out_dir == (
f'test_dir/{osp.basename(work_dir)}')
# create_symlink in args and create_symlink is True # create_symlink in args and create_symlink is True
checkpoint_hook = CheckpointHook( checkpoint_hook = CheckpointHook(
interval=1, by_epoch=True, out_dir='test_dir', create_symlink=True) interval=1, by_epoch=True, out_dir='test_dir', create_symlink=True)
checkpoint_hook.before_run(runner) checkpoint_hook.before_train(runner)
assert checkpoint_hook.args['create_symlink'] assert checkpoint_hook.args['create_symlink']
runner.work_dir = 's3://path/of/file' runner.work_dir = 's3://path/of/file'
checkpoint_hook = CheckpointHook( checkpoint_hook = CheckpointHook(
interval=1, by_epoch=True, create_symlink=True) interval=1, by_epoch=True, create_symlink=True)
checkpoint_hook.before_run(runner) checkpoint_hook.before_train(runner)
assert not checkpoint_hook.args['create_symlink'] assert not checkpoint_hook.args['create_symlink']
def test_after_train_epoch(self): def test_after_train_epoch(self, tmp_path):
runner = Mock() runner = Mock()
runner.work_dir = './tmp' work_dir = str(tmp_path)
runner.work_dir = tmp_path
runner.epoch = 9 runner.epoch = 9
runner.meta = dict() runner.meta = dict()
runner.model = Mock() runner.model = Mock()
# by epoch is True # by epoch is True
checkpoint_hook = CheckpointHook(interval=2, by_epoch=True) checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
checkpoint_hook.before_run(runner) checkpoint_hook.before_train(runner)
checkpoint_hook.after_train_epoch(runner) checkpoint_hook.after_train_epoch(runner)
assert (runner.epoch + 1) % 2 == 0 assert (runner.epoch + 1) % 2 == 0
assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/epoch_10.pth' assert runner.meta['hook_msgs']['last_ckpt'] == (
f'{work_dir}/epoch_10.pth')
# epoch can not be evenly divided by 2 # epoch can not be evenly divided by 2
runner.epoch = 10 runner.epoch = 10
checkpoint_hook.after_train_epoch(runner) checkpoint_hook.after_train_epoch(runner)
assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/epoch_10.pth' assert runner.meta['hook_msgs']['last_ckpt'] == (
f'{work_dir}/epoch_10.pth')
# by epoch is False # by epoch is False
runner.epoch = 9 runner.epoch = 9
runner.meta = dict() runner.meta = dict()
checkpoint_hook = CheckpointHook(interval=2, by_epoch=False) checkpoint_hook = CheckpointHook(interval=2, by_epoch=False)
checkpoint_hook.before_run(runner) checkpoint_hook.before_train(runner)
checkpoint_hook.after_train_epoch(runner) checkpoint_hook.after_train_epoch(runner)
assert runner.meta.get('hook_msgs', None) is None assert runner.meta.get('hook_msgs', None) is None
# max_keep_ckpts > 0 # max_keep_ckpts > 0
with TemporaryDirectory() as tempo_dir: runner.work_dir = work_dir
runner.work_dir = tempo_dir os.system(f'touch {work_dir}/epoch_8.pth')
os.system(f'touch {tempo_dir}/epoch_8.pth') checkpoint_hook = CheckpointHook(
checkpoint_hook = CheckpointHook( interval=2, by_epoch=True, max_keep_ckpts=1)
interval=2, by_epoch=True, max_keep_ckpts=1) checkpoint_hook.before_train(runner)
checkpoint_hook.before_run(runner) checkpoint_hook.after_train_epoch(runner)
checkpoint_hook.after_train_epoch(runner) assert (runner.epoch + 1) % 2 == 0
assert (runner.epoch + 1) % 2 == 0 assert not os.path.exists(f'{work_dir}/epoch_8.pth')
assert not os.path.exists(f'{tempo_dir}/epoch_8.pth')
def test_after_train_iter(self, tmp_path):
def test_after_train_iter(self): work_dir = str(tmp_path)
runner = Mock() runner = Mock()
runner.work_dir = './tmp' runner.work_dir = str(work_dir)
runner.iter = 9 runner.iter = 9
batch_idx = 9 batch_idx = 9
runner.meta = dict() runner.meta = dict()
...@@ -106,29 +108,30 @@ class TestCheckpointHook: ...@@ -106,29 +108,30 @@ class TestCheckpointHook:
# by epoch is True # by epoch is True
checkpoint_hook = CheckpointHook(interval=2, by_epoch=True) checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
checkpoint_hook.before_run(runner) checkpoint_hook.before_train(runner)
checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx) checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
assert runner.meta.get('hook_msgs', None) is None assert runner.meta.get('hook_msgs', None) is None
# by epoch is False # by epoch is False
checkpoint_hook = CheckpointHook(interval=2, by_epoch=False) checkpoint_hook = CheckpointHook(interval=2, by_epoch=False)
checkpoint_hook.before_run(runner) checkpoint_hook.before_train(runner)
checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx) checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
assert (runner.iter + 1) % 2 == 0 assert (runner.iter + 1) % 2 == 0
assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/iter_10.pth' assert runner.meta['hook_msgs']['last_ckpt'] == (
f'{work_dir}/iter_10.pth')
# epoch can not be evenly divided by 2 # epoch can not be evenly divided by 2
runner.iter = 10 runner.iter = 10
checkpoint_hook.after_train_epoch(runner) checkpoint_hook.after_train_epoch(runner)
assert runner.meta['hook_msgs']['last_ckpt'] == './tmp/iter_10.pth' assert runner.meta['hook_msgs']['last_ckpt'] == (
f'{work_dir}/iter_10.pth')
# max_keep_ckpts > 0 # max_keep_ckpts > 0
runner.iter = 9 runner.iter = 9
with TemporaryDirectory() as tempo_dir: runner.work_dir = work_dir
runner.work_dir = tempo_dir os.system(f'touch {work_dir}/iter_8.pth')
os.system(f'touch {tempo_dir}/iter_8.pth') checkpoint_hook = CheckpointHook(
checkpoint_hook = CheckpointHook( interval=2, by_epoch=False, max_keep_ckpts=1)
interval=2, by_epoch=False, max_keep_ckpts=1) checkpoint_hook.before_train(runner)
checkpoint_hook.before_run(runner) checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx) assert not os.path.exists(f'{work_dir}/iter_8.pth')
assert not os.path.exists(f'{tempo_dir}/iter_8.pth')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment