Skip to content
Snippets Groups Projects
Unverified Commit 26f24296 authored by Yuan Liu's avatar Yuan Liu Committed by GitHub
Browse files

[Feature]: Add dist semantics in checkpoint hook (#131)

* [Feature]: Add dist semantics in checkpoint hook

* [Fix]: Delete sync buffer in checkpoint hook
parent e4859030
No related branches found
No related tags found
No related merge requests found
...@@ -5,6 +5,7 @@ from pathlib import Path ...@@ -5,6 +5,7 @@ from pathlib import Path
from typing import Any, Optional, Sequence, Tuple, Union from typing import Any, Optional, Sequence, Tuple, Union
from mmengine.data import BaseDataSample from mmengine.data import BaseDataSample
from mmengine.dist import master_only
from mmengine.fileio import FileClient from mmengine.fileio import FileClient
from mmengine.registry import HOOKS from mmengine.registry import HOOKS
from .hook import Hook from .hook import Hook
...@@ -41,8 +42,6 @@ class CheckpointHook(Hook): ...@@ -41,8 +42,6 @@ class CheckpointHook(Hook):
Default: -1, which means unlimited. Default: -1, which means unlimited.
save_last (bool): Whether to force the last checkpoint to be save_last (bool): Whether to force the last checkpoint to be
saved regardless of interval. Default: True. saved regardless of interval. Default: True.
sync_buffer (bool): Whether to synchronize buffers in
different gpus. Default: False.
file_client_args (dict, optional): Arguments to instantiate a file_client_args (dict, optional): Arguments to instantiate a
FileClient. See :class:`mmcv.fileio.FileClient` for details. FileClient. See :class:`mmcv.fileio.FileClient` for details.
Default: None. Default: None.
...@@ -59,7 +58,6 @@ class CheckpointHook(Hook): ...@@ -59,7 +58,6 @@ class CheckpointHook(Hook):
out_dir: Optional[Union[str, Path]] = None, out_dir: Optional[Union[str, Path]] = None,
max_keep_ckpts: int = -1, max_keep_ckpts: int = -1,
save_last: bool = True, save_last: bool = True,
sync_buffer: bool = False,
file_client_args: Optional[dict] = None, file_client_args: Optional[dict] = None,
**kwargs) -> None: **kwargs) -> None:
self.interval = interval self.interval = interval
...@@ -70,7 +68,6 @@ class CheckpointHook(Hook): ...@@ -70,7 +68,6 @@ class CheckpointHook(Hook):
self.max_keep_ckpts = max_keep_ckpts self.max_keep_ckpts = max_keep_ckpts
self.save_last = save_last self.save_last = save_last
self.args = kwargs self.args = kwargs
self.sync_buffer = sync_buffer
self.file_client_args = file_client_args self.file_client_args = file_client_args
def before_run(self, runner) -> None: def before_run(self, runner) -> None:
...@@ -129,12 +126,9 @@ class CheckpointHook(Hook): ...@@ -129,12 +126,9 @@ class CheckpointHook(Hook):
self.save_last and self.is_last_train_epoch(runner)): self.save_last and self.is_last_train_epoch(runner)):
runner.logger.info( runner.logger.info(
f'Saving checkpoint at {runner.epoch + 1} epochs') f'Saving checkpoint at {runner.epoch + 1} epochs')
if self.sync_buffer:
pass
# TODO
self._save_checkpoint(runner) self._save_checkpoint(runner)
# TODO Add master_only decorator @master_only
def _save_checkpoint(self, runner) -> None: def _save_checkpoint(self, runner) -> None:
"""Save the current checkpoint and delete outdated checkpoint. """Save the current checkpoint and delete outdated checkpoint.
...@@ -204,7 +198,4 @@ class CheckpointHook(Hook): ...@@ -204,7 +198,4 @@ class CheckpointHook(Hook):
(self.save_last and self.is_last_iter(runner, mode='train')): (self.save_last and self.is_last_iter(runner, mode='train')):
runner.logger.info( runner.logger.info(
f'Saving checkpoint at {runner.iter + 1} iterations') f'Saving checkpoint at {runner.iter + 1} iterations')
if self.sync_buffer:
pass
# TODO
self._save_checkpoint(runner) self._save_checkpoint(runner)
...@@ -63,6 +63,7 @@ class TestCheckpointHook: ...@@ -63,6 +63,7 @@ class TestCheckpointHook:
runner.work_dir = './tmp' runner.work_dir = './tmp'
runner.epoch = 9 runner.epoch = 9
runner.meta = dict() runner.meta = dict()
runner.model = Mock()
# by epoch is True # by epoch is True
checkpoint_hook = CheckpointHook(interval=2, by_epoch=True) checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
...@@ -100,6 +101,7 @@ class TestCheckpointHook: ...@@ -100,6 +101,7 @@ class TestCheckpointHook:
runner.work_dir = './tmp' runner.work_dir = './tmp'
runner.iter = 9 runner.iter = 9
runner.meta = dict() runner.meta = dict()
runner.model = Mock()
# by epoch is True # by epoch is True
checkpoint_hook = CheckpointHook(interval=2, by_epoch=True) checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment