Skip to content
Snippets Groups Projects
Unverified Commit f1aca8e3 authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

[Fix] Failed to remove the previous best checkpoints (#1086)

* [Fix] Only reserve one best checkpoint

* [Fix] Only reserve one best checkpoint

* Fix unit test

* shutdown logging

* clean the save_checkpoint logic
parent 6ebb6f83
No related branches found
No related tags found
No related merge requests found
...@@ -479,9 +479,9 @@ class CheckpointHook(Hook): ...@@ -479,9 +479,9 @@ class CheckpointHook(Hook):
runner.message_hub.update_info(best_score_key, best_score) runner.message_hub.update_info(best_score_key, best_score)
if best_ckpt_path and \ if best_ckpt_path and \
self.file_client.isfile(best_ckpt_path) and \ self.file_backend.isfile(best_ckpt_path) and \
is_main_process(): is_main_process():
self.file_client.remove(best_ckpt_path) self.file_backend.remove(best_ckpt_path)
runner.logger.info( runner.logger.info(
f'The previous best checkpoint {best_ckpt_path} ' f'The previous best checkpoint {best_ckpt_path} '
'is removed') 'is removed')
...@@ -490,13 +490,13 @@ class CheckpointHook(Hook): ...@@ -490,13 +490,13 @@ class CheckpointHook(Hook):
# Replace illegal characters for filename with `_` # Replace illegal characters for filename with `_`
best_ckpt_name = best_ckpt_name.replace('/', '_') best_ckpt_name = best_ckpt_name.replace('/', '_')
if len(self.key_indicators) == 1: if len(self.key_indicators) == 1:
self.best_ckpt_path = self.file_client.join_path( # type: ignore # noqa: E501 self.best_ckpt_path = self.file_backend.join_path( # type: ignore # noqa: E501
self.out_dir, best_ckpt_name) self.out_dir, best_ckpt_name)
runner.message_hub.update_info(runtime_best_ckpt_key, runner.message_hub.update_info(runtime_best_ckpt_key,
self.best_ckpt_path) self.best_ckpt_path)
else: else:
self.best_ckpt_path_dict[ self.best_ckpt_path_dict[
key_indicator] = self.file_client.join_path( # type: ignore # noqa: E501 key_indicator] = self.file_backend.join_path( # type: ignore # noqa: E501
self.out_dir, best_ckpt_name) self.out_dir, best_ckpt_name)
runner.message_hub.update_info( runner.message_hub.update_info(
runtime_best_ckpt_key, runtime_best_ckpt_key,
......
...@@ -2191,7 +2191,11 @@ class Runner: ...@@ -2191,7 +2191,11 @@ class Runner:
checkpoint['param_schedulers'].append(state_dict) checkpoint['param_schedulers'].append(state_dict)
self.call_hook('before_save_checkpoint', checkpoint=checkpoint) self.call_hook('before_save_checkpoint', checkpoint=checkpoint)
save_checkpoint(checkpoint, filepath) save_checkpoint(
checkpoint,
filepath,
file_client_args=file_client_args,
backend_args=backend_args)
@master_only @master_only
def dump_config(self) -> None: def dump_config(self) -> None:
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import copy import copy
import logging import logging
import os import os
import shutil
import tempfile import tempfile
import time import time
from unittest import TestCase from unittest import TestCase
...@@ -184,3 +185,12 @@ class RunnerTestCase(TestCase): ...@@ -184,3 +185,12 @@ class RunnerTestCase(TestCase):
os.environ['RANK'] = self.dist_cfg['RANK'] os.environ['RANK'] = self.dist_cfg['RANK']
os.environ['WORLD_SIZE'] = self.dist_cfg['WORLD_SIZE'] os.environ['WORLD_SIZE'] = self.dist_cfg['WORLD_SIZE']
os.environ['LOCAL_RANK'] = self.dist_cfg['LOCAL_RANK'] os.environ['LOCAL_RANK'] = self.dist_cfg['LOCAL_RANK']
def clear_work_dir(self):
logging.shutdown()
for filename in os.listdir(self.temp_dir.name):
filepath = os.path.join(self.temp_dir.name, filename)
if os.path.isfile(filepath):
os.remove(filepath)
else:
shutil.rmtree(filepath)
...@@ -3,6 +3,8 @@ import copy ...@@ -3,6 +3,8 @@ import copy
import os import os
import os.path as osp import os.path as osp
import re import re
import sys
from unittest.mock import MagicMock, patch
import torch import torch
from parameterized import parameterized from parameterized import parameterized
...@@ -312,6 +314,54 @@ class TestCheckpointHook(RunnerTestCase): ...@@ -312,6 +314,54 @@ class TestCheckpointHook(RunnerTestCase):
self.assertFalse( self.assertFalse(
osp.isfile(osp.join(runner.work_dir, 'last_checkpoint'))) osp.isfile(osp.join(runner.work_dir, 'last_checkpoint')))
# There should only one best checkpoint be reserved
# dist backend
for by_epoch, cfg in [(True, self.epoch_based_cfg),
(False, self.iter_based_cfg)]:
self.clear_work_dir()
cfg = copy.deepcopy(cfg)
runner = self.build_runner(cfg)
checkpoint_hook = CheckpointHook(
interval=2, by_epoch=by_epoch, save_best='acc')
checkpoint_hook.before_train(runner)
checkpoint_hook.after_val_epoch(runner, metrics)
all_files = os.listdir(runner.work_dir)
best_ckpts = [
file for file in all_files if file.startswith('best')
]
self.assertTrue(len(best_ckpts) == 1)
# petrel backend
# TODO use real petrel oss bucket to test
petrel_client = MagicMock()
for by_epoch, cfg in [(True, self.epoch_based_cfg),
(False, self.iter_based_cfg)]:
isfile = MagicMock(return_value=True)
self.clear_work_dir()
with patch.dict(sys.modules, {'petrel_client': petrel_client}), \
patch('mmengine.fileio.backends.PetrelBackend.put') as put_mock, \
patch('mmengine.fileio.backends.PetrelBackend.remove') as remove_mock, \
patch('mmengine.fileio.backends.PetrelBackend.isfile') as isfile: # noqa: E501
cfg = copy.deepcopy(cfg)
runner = self.build_runner(cfg)
metrics = dict(acc=0.5)
petrel_client.client.Client = MagicMock(
return_value=petrel_client)
checkpoint_hook = CheckpointHook(
interval=2,
by_epoch=by_epoch,
save_best='acc',
backend_args=dict(backend='petrel'))
checkpoint_hook.before_train(runner)
checkpoint_hook.after_val_epoch(runner, metrics)
put_mock.assert_called_once()
metrics['acc'] += 0.1
runner.train_loop._epoch += 1
runner.train_loop._iter += 1
checkpoint_hook.after_val_epoch(runner, metrics)
isfile.assert_called_once()
remove_mock.assert_called_once()
def test_after_train_epoch(self): def test_after_train_epoch(self):
cfg = copy.deepcopy(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
runner = self.build_runner(cfg) runner = self.build_runner(cfg)
......
...@@ -20,9 +20,6 @@ from mmengine.runner.checkpoint import (CheckpointLoader, ...@@ -20,9 +20,6 @@ from mmengine.runner.checkpoint import (CheckpointLoader,
load_from_local, load_from_pavi, load_from_local, load_from_pavi,
save_checkpoint) save_checkpoint)
sys.modules['petrel_client'] = MagicMock()
sys.modules['petrel_client.client'] = MagicMock()
@MODEL_WRAPPERS.register_module() @MODEL_WRAPPERS.register_module()
class DDPWrapper: class DDPWrapper:
...@@ -150,9 +147,8 @@ def test_get_state_dict(): ...@@ -150,9 +147,8 @@ def test_get_state_dict():
wrapped_model.module.conv.module.bias) wrapped_model.module.conv.module.bias)
@patch.dict(sys.modules, {'pavi': MagicMock()})
def test_load_pavimodel_dist(): def test_load_pavimodel_dist():
sys.modules['pavi'] = MagicMock()
sys.modules['pavi.modelcloud'] = MagicMock()
pavimodel = Mockpavimodel() pavimodel = Mockpavimodel()
import pavi import pavi
pavi.modelcloud.get = MagicMock(return_value=pavimodel) pavi.modelcloud.get = MagicMock(return_value=pavimodel)
...@@ -296,6 +292,7 @@ def test_load_checkpoint_metadata(): ...@@ -296,6 +292,7 @@ def test_load_checkpoint_metadata():
assert torch.allclose(model_v2.conv1.weight, model_v2_conv1_weight) assert torch.allclose(model_v2.conv1.weight, model_v2_conv1_weight)
@patch.dict(sys.modules, {'petrel_client': MagicMock()})
def test_checkpoint_loader(): def test_checkpoint_loader():
filenames = [ filenames = [
'http://xx.xx/xx.pth', 'https://xx.xx/xx.pth', 'http://xx.xx/xx.pth', 'https://xx.xx/xx.pth',
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment