Skip to content
Snippets Groups Projects
Unverified Commit 193b7fdf authored by Zaida Zhou's avatar Zaida Zhou Committed by GitHub
Browse files

[Refactor] Let unit tests not affect each other (#1169)

parent 5d4e7214
No related branches found
No related tags found
No related merge requests found
...@@ -85,7 +85,7 @@ class CheckpointHook(Hook): ...@@ -85,7 +85,7 @@ class CheckpointHook(Hook):
accordingly. accordingly.
backend_args (dict, optional): Arguments to instantiate the backend_args (dict, optional): Arguments to instantiate the
prefix of uri corresponding backend. Defaults to None. prefix of uri corresponding backend. Defaults to None.
New in v0.2.0. `New in version 0.2.0.`
published_keys (str, List[str], optional): If ``save_last`` is ``True`` published_keys (str, List[str], optional): If ``save_last`` is ``True``
or ``save_best`` is not ``None``, it will automatically or ``save_best`` is not ``None``, it will automatically
publish model with keys in the list after training. publish model with keys in the list after training.
......
...@@ -429,34 +429,37 @@ class TestCheckpointHook(RunnerTestCase): ...@@ -429,34 +429,37 @@ class TestCheckpointHook(RunnerTestCase):
@parameterized.expand([['iter'], ['epoch']]) @parameterized.expand([['iter'], ['epoch']])
def test_with_runner(self, training_type): def test_with_runner(self, training_type):
# Test interval in epoch based training common_cfg = getattr(self, f'{training_type}_based_cfg')
save_iterval = 2 setattr(common_cfg.train_cfg, f'max_{training_type}s', 11)
cfg = copy.deepcopy(getattr(self, f'{training_type}_based_cfg'))
setattr(cfg.train_cfg, f'max_{training_type}s', 11)
checkpoint_cfg = dict( checkpoint_cfg = dict(
type='CheckpointHook', type='CheckpointHook',
interval=save_iterval, interval=2,
by_epoch=training_type == 'epoch') by_epoch=training_type == 'epoch')
cfg.default_hooks = dict(checkpoint=checkpoint_cfg) common_cfg.default_hooks = dict(checkpoint=checkpoint_cfg)
# Test interval in epoch based training
cfg = copy.deepcopy(common_cfg)
runner = self.build_runner(cfg) runner = self.build_runner(cfg)
runner.train() runner.train()
for i in range(1, 11): for i in range(1, 11):
if i == 0: self.assertEqual(
self.assertFalse( osp.isfile(osp.join(cfg.work_dir, f'{training_type}_{i}.pth')),
osp.isfile( i % 2 == 0)
osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
if i % 2 == 0:
self.assertTrue(
osp.isfile(
osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
# save_last=True
self.assertTrue( self.assertTrue(
osp.isfile(osp.join(cfg.work_dir, f'{training_type}_11.pth'))) osp.isfile(osp.join(cfg.work_dir, f'{training_type}_11.pth')))
self.clear_work_dir()
# Test save_optimizer=False # Test save_optimizer=False
cfg = copy.deepcopy(common_cfg)
runner = self.build_runner(cfg)
runner.train()
ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth')) ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
self.assertIn('optimizer', ckpt) self.assertIn('optimizer', ckpt)
cfg.default_hooks.checkpoint.save_optimizer = False cfg.default_hooks.checkpoint.save_optimizer = False
runner = self.build_runner(cfg) runner = self.build_runner(cfg)
runner.train() runner.train()
...@@ -464,6 +467,7 @@ class TestCheckpointHook(RunnerTestCase): ...@@ -464,6 +467,7 @@ class TestCheckpointHook(RunnerTestCase):
self.assertNotIn('optimizer', ckpt) self.assertNotIn('optimizer', ckpt)
# Test save_param_scheduler=False # Test save_param_scheduler=False
cfg = copy.deepcopy(common_cfg)
cfg.param_scheduler = [ cfg.param_scheduler = [
dict( dict(
type='LinearLR', type='LinearLR',
...@@ -483,7 +487,10 @@ class TestCheckpointHook(RunnerTestCase): ...@@ -483,7 +487,10 @@ class TestCheckpointHook(RunnerTestCase):
ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth')) ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth'))
self.assertNotIn('param_schedulers', ckpt) self.assertNotIn('param_schedulers', ckpt)
self.clear_work_dir()
# Test out_dir # Test out_dir
cfg = copy.deepcopy(common_cfg)
out_dir = osp.join(self.temp_dir.name, 'out_dir') out_dir = osp.join(self.temp_dir.name, 'out_dir')
cfg.default_hooks.checkpoint.out_dir = out_dir cfg.default_hooks.checkpoint.out_dir = out_dir
runner = self.build_runner(cfg) runner = self.build_runner(cfg)
...@@ -493,37 +500,54 @@ class TestCheckpointHook(RunnerTestCase): ...@@ -493,37 +500,54 @@ class TestCheckpointHook(RunnerTestCase):
osp.join(out_dir, osp.basename(cfg.work_dir), osp.join(out_dir, osp.basename(cfg.work_dir),
f'{training_type}_11.pth'))) f'{training_type}_11.pth')))
# Test max_keep_ckpts. self.clear_work_dir()
del cfg.default_hooks.checkpoint.out_dir
# Test max_keep_ckpts
cfg = copy.deepcopy(common_cfg)
cfg.default_hooks.checkpoint.interval = 1
cfg.default_hooks.checkpoint.max_keep_ckpts = 1 cfg.default_hooks.checkpoint.max_keep_ckpts = 1
runner = self.build_runner(cfg) runner = self.build_runner(cfg)
runner.train() runner.train()
print(os.listdir(cfg.work_dir))
self.assertTrue( self.assertTrue(
osp.isfile(osp.join(cfg.work_dir, f'{training_type}_10.pth'))) osp.isfile(osp.join(cfg.work_dir, f'{training_type}_11.pth')))
for i in range(10): for i in range(11):
self.assertFalse( self.assertFalse(
osp.isfile(osp.join(cfg.work_dir, f'{training_type}_{i}.pth'))) osp.isfile(osp.join(cfg.work_dir, f'{training_type}_{i}.pth')))
self.clear_work_dir()
# Test filename_tmpl # Test filename_tmpl
cfg = copy.deepcopy(common_cfg)
cfg.default_hooks.checkpoint.filename_tmpl = 'test_{}.pth' cfg.default_hooks.checkpoint.filename_tmpl = 'test_{}.pth'
runner = self.build_runner(cfg) runner = self.build_runner(cfg)
runner.train() runner.train()
self.assertTrue(osp.isfile(osp.join(cfg.work_dir, 'test_10.pth'))) self.assertTrue(osp.isfile(osp.join(cfg.work_dir, 'test_11.pth')))
self.clear_work_dir()
# Test save_best # Test save_best
cfg = copy.deepcopy(common_cfg)
cfg.default_hooks.checkpoint.interval = 1
cfg.default_hooks.checkpoint.save_best = 'test/acc' cfg.default_hooks.checkpoint.save_best = 'test/acc'
cfg.val_evaluator = dict(type='TriangleMetric', length=11) cfg.val_evaluator = dict(type='TriangleMetric', length=11)
cfg.train_cfg.val_interval = 1 cfg.train_cfg.val_interval = 1
runner = self.build_runner(cfg) runner = self.build_runner(cfg)
runner.train() runner.train()
self.assertTrue( best_ckpt = osp.join(cfg.work_dir,
osp.isfile(osp.join(cfg.work_dir, 'best_test_acc_test_5.pth'))) f'best_test_acc_{training_type}_5.pth')
self.assertTrue(osp.isfile(best_ckpt))
self.clear_work_dir()
# test save published keys # test save published keys
cfg = copy.deepcopy(common_cfg)
cfg.default_hooks.checkpoint.published_keys = ['meta', 'state_dict'] cfg.default_hooks.checkpoint.published_keys = ['meta', 'state_dict']
runner = self.build_runner(cfg) runner = self.build_runner(cfg)
runner.train() runner.train()
ckpt_files = os.listdir(runner.work_dir) ckpt_files = os.listdir(runner.work_dir)
self.assertTrue( self.assertTrue(
any(re.findall(r'-[\d\w]{8}\.pth', file) for file in ckpt_files)) any(re.findall(r'-[\d\w]{8}\.pth', file) for file in ckpt_files))
self.clear_work_dir()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment