Skip to content
Snippets Groups Projects
Unverified Commit c64243aa authored by Qian Zhao's avatar Qian Zhao Committed by GitHub
Browse files

[Fix] CheckpointHook behavior incorrect if given `filename_tmpl` argument (#518)

parent e56b1736
No related branches found
No related tags found
No related merge requests found
...@@ -74,6 +74,12 @@ class CheckpointHook(Hook): ...@@ -74,6 +74,12 @@ class CheckpointHook(Hook):
file_client_args (dict, optional): Arguments to instantiate a file_client_args (dict, optional): Arguments to instantiate a
FileClient. See :class:`mmcv.fileio.FileClient` for details. FileClient. See :class:`mmcv.fileio.FileClient` for details.
Defaults to None. Defaults to None.
filename_tmpl (str, optional): String template to indicate checkpoint
name. If specified, must contain one and only one "{}", which will
be replaced with ``epoch + 1`` if ``by_epoch=True`` else
``iteration + 1``.
Defaults to None, which means "epoch_{}.pth" or "iter_{}.pth"
accordingly.
Examples: Examples:
>>> # Save best based on single metric >>> # Save best based on single metric
...@@ -116,6 +122,7 @@ class CheckpointHook(Hook): ...@@ -116,6 +122,7 @@ class CheckpointHook(Hook):
greater_keys: Optional[Sequence[str]] = None, greater_keys: Optional[Sequence[str]] = None,
less_keys: Optional[Sequence[str]] = None, less_keys: Optional[Sequence[str]] = None,
file_client_args: Optional[dict] = None, file_client_args: Optional[dict] = None,
filename_tmpl: Optional[str] = None,
**kwargs) -> None: **kwargs) -> None:
self.interval = interval self.interval = interval
self.by_epoch = by_epoch self.by_epoch = by_epoch
...@@ -124,8 +131,15 @@ class CheckpointHook(Hook): ...@@ -124,8 +131,15 @@ class CheckpointHook(Hook):
self.out_dir = out_dir # type: ignore self.out_dir = out_dir # type: ignore
self.max_keep_ckpts = max_keep_ckpts self.max_keep_ckpts = max_keep_ckpts
self.save_last = save_last self.save_last = save_last
self.args = kwargs
self.file_client_args = file_client_args self.file_client_args = file_client_args
if filename_tmpl is None:
if self.by_epoch:
self.filename_tmpl = 'epoch_{}.pth'
else:
self.filename_tmpl = 'iter_{}.pth'
else:
self.filename_tmpl = filename_tmpl
self.args = kwargs
# save best logic # save best logic
assert (isinstance(save_best, str) or is_list_of(save_best, str) assert (isinstance(save_best, str) or is_list_of(save_best, str)
...@@ -277,11 +291,9 @@ class CheckpointHook(Hook): ...@@ -277,11 +291,9 @@ class CheckpointHook(Hook):
runner (Runner): The runner of the training process. runner (Runner): The runner of the training process.
""" """
if self.by_epoch: if self.by_epoch:
ckpt_filename = self.args.get( ckpt_filename = self.filename_tmpl.format(runner.epoch + 1)
'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1)
else: else:
ckpt_filename = self.args.get( ckpt_filename = self.filename_tmpl.format(runner.iter + 1)
'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
runner.save_checkpoint( runner.save_checkpoint(
self.out_dir, self.out_dir,
...@@ -299,18 +311,15 @@ class CheckpointHook(Hook): ...@@ -299,18 +311,15 @@ class CheckpointHook(Hook):
# remove other checkpoints # remove other checkpoints
if self.max_keep_ckpts > 0: if self.max_keep_ckpts > 0:
if self.by_epoch: if self.by_epoch:
name = 'epoch_{}.pth'
current_ckpt = runner.epoch + 1 current_ckpt = runner.epoch + 1
else: else:
name = 'iter_{}.pth'
current_ckpt = runner.iter + 1 current_ckpt = runner.iter + 1
redundant_ckpts = range( redundant_ckpts = range(
current_ckpt - self.max_keep_ckpts * self.interval, 0, current_ckpt - self.max_keep_ckpts * self.interval, 0,
-self.interval) -self.interval)
filename_tmpl = self.args.get('filename_tmpl', name)
for _step in redundant_ckpts: for _step in redundant_ckpts:
ckpt_path = self.file_client.join_path( ckpt_path = self.file_client.join_path(
self.out_dir, filename_tmpl.format(_step)) self.out_dir, self.filename_tmpl.format(_step))
if self.file_client.isfile(ckpt_path): if self.file_client.isfile(ckpt_path):
self.file_client.remove(ckpt_path) self.file_client.remove(ckpt_path)
else: else:
...@@ -334,12 +343,10 @@ class CheckpointHook(Hook): ...@@ -334,12 +343,10 @@ class CheckpointHook(Hook):
return return
if self.by_epoch: if self.by_epoch:
ckpt_filename = self.args.get('filename_tmpl', ckpt_filename = self.filename_tmpl.format(runner.epoch)
'epoch_{}.pth').format(runner.epoch)
cur_type, cur_time = 'epoch', runner.epoch cur_type, cur_time = 'epoch', runner.epoch
else: else:
ckpt_filename = self.args.get('filename_tmpl', ckpt_filename = self.filename_tmpl.format(runner.iter)
'iter_{}.pth').format(runner.iter)
cur_type, cur_time = 'iter', runner.iter cur_type, cur_time = 'iter', runner.iter
# handle auto in self.key_indicators and self.rules before the loop # handle auto in self.key_indicators and self.rules before the loop
......
...@@ -4,9 +4,71 @@ import os.path as osp ...@@ -4,9 +4,71 @@ import os.path as osp
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import pytest import pytest
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from mmengine.evaluator import BaseMetric
from mmengine.hooks import CheckpointHook from mmengine.hooks import CheckpointHook
from mmengine.logging import MessageHub from mmengine.logging import MessageHub
from mmengine.model import BaseModel
from mmengine.optim import OptimWrapper
from mmengine.runner import Runner
class ToyModel(BaseModel):
def __init__(self):
super().__init__()
self.linear = nn.Linear(2, 1)
def forward(self, inputs, data_sample, mode='tensor'):
labels = torch.stack(data_sample)
inputs = torch.stack(inputs)
outputs = self.linear(inputs)
if mode == 'tensor':
return outputs
elif mode == 'loss':
loss = (labels - outputs).sum()
outputs = dict(loss=loss)
return outputs
else:
return outputs
class DummyDataset(Dataset):
METAINFO = dict() # type: ignore
data = torch.randn(12, 2)
label = torch.ones(12)
@property
def metainfo(self):
return self.METAINFO
def __len__(self):
return self.data.size(0)
def __getitem__(self, index):
return dict(inputs=self.data[index], data_sample=self.label[index])
class TriangleMetric(BaseMetric):
default_prefix: str = 'test'
def __init__(self, length):
super().__init__()
self.length = length
self.best_idx = length // 2
self.cur_idx = 0
def process(self, *args, **kwargs):
self.results.append(0)
def compute_metrics(self, *args, **kwargs):
self.cur_idx += 1
acc = 1.0 - abs(self.cur_idx - self.best_idx) / self.length
return dict(acc=acc)
class MockPetrel: class MockPetrel:
...@@ -370,3 +432,40 @@ class TestCheckpointHook: ...@@ -370,3 +432,40 @@ class TestCheckpointHook:
checkpoint_hook.before_train(runner) checkpoint_hook.before_train(runner)
checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx) checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
assert not os.path.exists(f'{work_dir}/iter_8.pth') assert not os.path.exists(f'{work_dir}/iter_8.pth')
def test_with_runner(self, tmp_path):
max_epoch = 10
work_dir = osp.join(str(tmp_path), 'runner_test')
tmpl = '{}.pth'
save_interval = 2
checkpoint_cfg = dict(
type='CheckpointHook',
interval=save_interval,
filename_tmpl=tmpl,
by_epoch=True)
runner = Runner(
model=ToyModel(),
work_dir=work_dir,
train_dataloader=dict(
dataset=DummyDataset(),
sampler=dict(type='DefaultSampler', shuffle=True),
batch_size=3,
num_workers=0),
val_dataloader=dict(
dataset=DummyDataset(),
sampler=dict(type='DefaultSampler', shuffle=False),
batch_size=3,
num_workers=0),
val_evaluator=dict(type=TriangleMetric, length=max_epoch),
optim_wrapper=OptimWrapper(
torch.optim.Adam(ToyModel().parameters())),
train_cfg=dict(
by_epoch=True, max_epochs=max_epoch, val_interval=1),
val_cfg=dict(),
default_hooks=dict(checkpoint=checkpoint_cfg))
runner.train()
for epoch in range(max_epoch):
if epoch % save_interval != 0 or epoch == 0:
continue
path = osp.join(work_dir, tmpl.format(epoch))
assert osp.isfile(path=path)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment