diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py index 71929c0cdb50b9a75d21c3067994c602e78eb0c4..17e19737917dbe3fe36721950f5f7be7e2e0fbb4 100644 --- a/mmengine/hooks/checkpoint_hook.py +++ b/mmengine/hooks/checkpoint_hook.py @@ -5,6 +5,7 @@ from pathlib import Path from typing import Any, Optional, Sequence, Tuple, Union from mmengine.data import BaseDataSample +from mmengine.dist import master_only from mmengine.fileio import FileClient from mmengine.registry import HOOKS from .hook import Hook @@ -41,8 +42,6 @@ class CheckpointHook(Hook): Default: -1, which means unlimited. save_last (bool): Whether to force the last checkpoint to be saved regardless of interval. Default: True. - sync_buffer (bool): Whether to synchronize buffers in - different gpus. Default: False. file_client_args (dict, optional): Arguments to instantiate a FileClient. See :class:`mmcv.fileio.FileClient` for details. Default: None. @@ -59,7 +58,6 @@ class CheckpointHook(Hook): out_dir: Optional[Union[str, Path]] = None, max_keep_ckpts: int = -1, save_last: bool = True, - sync_buffer: bool = False, file_client_args: Optional[dict] = None, **kwargs) -> None: self.interval = interval @@ -70,7 +68,6 @@ class CheckpointHook(Hook): self.max_keep_ckpts = max_keep_ckpts self.save_last = save_last self.args = kwargs - self.sync_buffer = sync_buffer self.file_client_args = file_client_args def before_run(self, runner) -> None: @@ -129,12 +126,9 @@ class CheckpointHook(Hook): self.save_last and self.is_last_train_epoch(runner)): runner.logger.info( f'Saving checkpoint at {runner.epoch + 1} epochs') - if self.sync_buffer: - pass - # TODO self._save_checkpoint(runner) - # TODO Add master_only decorator + @master_only def _save_checkpoint(self, runner) -> None: """Save the current checkpoint and delete outdated checkpoint. @@ -204,7 +198,4 @@ class CheckpointHook(Hook): (self.save_last and self.is_last_iter(runner, mode='train')): runner.logger.info( f'Saving checkpoint at {runner.iter + 1} iterations') - if self.sync_buffer: - pass - # TODO self._save_checkpoint(runner) diff --git a/tests/test_hook/test_checkpoint_hook.py b/tests/test_hook/test_checkpoint_hook.py index f45b433aab46c70956a7220768df5bc127093324..3cf661c01ed760a01d7cb1f1a28009044664cf45 100644 --- a/tests/test_hook/test_checkpoint_hook.py +++ b/tests/test_hook/test_checkpoint_hook.py @@ -63,6 +63,7 @@ class TestCheckpointHook: runner.work_dir = './tmp' runner.epoch = 9 runner.meta = dict() + runner.model = Mock() # by epoch is True checkpoint_hook = CheckpointHook(interval=2, by_epoch=True) @@ -100,6 +101,7 @@ class TestCheckpointHook: runner.work_dir = './tmp' runner.iter = 9 runner.meta = dict() + runner.model = Mock() # by epoch is True checkpoint_hook = CheckpointHook(interval=2, by_epoch=True)