# NOTE(review): the tokens "Newer" / "Older" were revision-navigation residue
# from the page this chunk was scraped from, not source content.
@HOOKS.register_module()
class TestWarmupHook(Hook):
    """Hook used to verify that the custom train loop fires warmup callbacks.

    Each warmup iteration of ``CustomTrainLoop2`` is expected to call
    ``before_warmup_iter`` and ``after_warmup_iter``; the hook records each
    call in the module-level result lists so the test can count and inspect
    them afterwards.
    """

    def before_warmup_iter(self, runner, data_batch=None):
        # Record the callback; the test asserts exactly 5 entries, matching
        # max_warmup_iters=5 in the loop config.
        before_warmup_iter_results.append('before')

    def after_warmup_iter(self, runner, data_batch=None, outputs=None):
        # NOTE(review): the body of this method was lost in extraction. The
        # downstream assertions (len == 5, every element == 'after') imply it
        # appended 'after' to the module-level list; restored accordingly.
        after_warmup_iter_results.append('after')
# NOTE(review): fragment — the enclosing method's `def` line was lost in
# extraction (experiment_name 'test_custom_loop' suggests this is the tail of
# a `test_custom_loop(self)` test), and all indentation was stripped. Code
# lines below are kept byte-identical; only comments are added.
# Configure a custom train loop that runs a warmup stage over a dedicated
# warmup dataloader (5 iters, per max_warmup_iters) before the main
# 10-iteration loop.
self.iter_based_cfg.train_cfg = dict(
type='CustomTrainLoop2',
max_iters=10,
warmup_loader=dict(
dataset=dict(type='ToyDataset'),
sampler=dict(type='InfiniteSampler', shuffle=True),
batch_size=1,
num_workers=0),
max_warmup_iters=5)
# Register the warmup hook so its before/after callbacks fire on each warmup
# iteration.
self.iter_based_cfg.custom_hooks = [
dict(type='TestWarmupHook', priority=50)
]
self.iter_based_cfg.experiment_name = 'test_custom_loop'
runner = Runner.from_cfg(self.iter_based_cfg)
self.assertIsInstance(runner.train_loop, CustomTrainLoop2)
# NOTE(review): a `runner.train()` call appears to be missing between the
# construction above and the assertions below — without it the result lists
# would be empty; likely lost in extraction. TODO confirm against upstream.
# test custom hook triggered as expected
self.assertEqual(len(before_warmup_iter_results), 5)
self.assertEqual(len(after_warmup_iter_results), 5)
for before, after in zip(before_warmup_iter_results,
after_warmup_iter_results):
self.assertEqual(before, 'before')
self.assertEqual(after, 'after')
def test_checkpoint(self):
# NOTE(review): the indentation of this method was stripped by extraction and
# its middle/tail are interrupted by garbled spans further down the chunk.
# Code lines are kept byte-identical; only comments are added.
# 1. test epoch based
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_checkpoint1'
runner = Runner.from_cfg(cfg)
runner.train()
# 1.1 test `save_checkpoint` which called by `CheckpointHook`
path = osp.join(self.temp_dir, 'epoch_3.pth')
self.assertTrue(osp.exists(path))
self.assertTrue(osp.exists(osp.join(self.temp_dir, 'latest.pth')))
# Training stops at epoch 3, so no epoch_4 checkpoint may exist.
self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_4.pth')))
# epoch 3 / iter 12 — presumably 3 epochs x 4 iters per epoch in
# epoch_based_cfg; TODO confirm against the fixture config.
ckpt = torch.load(path)
self.assertEqual(ckpt['meta']['epoch'], 3)
self.assertEqual(ckpt['meta']['iter'], 12)
# self.assertEqual(ckpt['meta']['hook_msgs']['last_ckpt'], path)
assert isinstance(ckpt['optimizer'], dict)
assert isinstance(ckpt['param_schedulers'], list)
# 1.2 test `load_checkpoint`
# Pins the contract that `load_checkpoint` restores weights only: epoch and
# iter remain 0 (unlike `resume`, tested in 1.3).
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_checkpoint2'
runner = Runner.from_cfg(cfg)
runner.load_checkpoint(path)
self.assertEqual(runner.epoch, 0)
self.assertEqual(runner.iter, 0)
self.assertTrue(runner._has_loaded)
# 1.3 test `resume`
# `resume` restores training progress (epoch/iter) as well as optimizer and
# scheduler state.
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_checkpoint3'
runner = Runner.from_cfg(cfg)
runner.resume(path)
self.assertEqual(runner.epoch, 3)
self.assertEqual(runner.iter, 12)
self.assertTrue(runner._has_loaded)
self.assertIsInstance(runner.optimizer, SGD)
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
# NOTE(review): the two lines below start the 'test_checkpoint4' section; its
# continuation was lost in the garbled span that follows in this chunk.
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_checkpoint4'
# NOTE(review): original source lines 1081-1105 were lost in extraction — only
# their line numbers survived here. Judging by the surrounding sections, the
# lost span most likely held the "# 1.4 test auto resume" header and part of
# the checkpoint4 auto-resume setup; confirm against the upstream file.
# (tail of the auto-resume section whose header was lost in the garbled span
# above; `cfg` is the deepcopy made at the 'test_checkpoint4' setup.)
cfg.resume = True
runner = Runner.from_cfg(cfg)
runner.load_or_resume()
# Auto resume picks up the latest checkpoint in the work dir and restores
# progress plus optimizer/scheduler state.
self.assertEqual(runner.epoch, 3)
self.assertEqual(runner.iter, 12)
self.assertTrue(runner._has_loaded)
self.assertIsInstance(runner.optimizer, SGD)
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
# 1.5 test resume from a specified checkpoint
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_checkpoint5'
cfg.resume = True
cfg.load_from = osp.join(self.temp_dir, 'epoch_1.pth')
runner = Runner.from_cfg(cfg)
runner.load_or_resume()
self.assertEqual(runner.epoch, 1)
self.assertEqual(runner.iter, 4)
self.assertTrue(runner._has_loaded)
self.assertIsInstance(runner.optimizer, SGD)
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
# 2. test iter based
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_checkpoint6'
# FIX(review): `runner = Runner.from_cfg(cfg)` was missing here, so the
# freshly copied cfg was dead and the stale runner from section 1.5 was
# re-trained. Every other section builds the runner right after setting
# `experiment_name`; restored for consistency.
runner = Runner.from_cfg(cfg)
runner.train()
# 2.1 test `save_checkpoint` which called by `CheckpointHook`
path = osp.join(self.temp_dir, 'iter_12.pth')
self.assertTrue(osp.exists(path))
self.assertTrue(osp.exists(osp.join(self.temp_dir, 'latest.pth')))
self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_13.pth')))
ckpt = torch.load(path)
# Iter-based training: epoch stays 0 while iter advances to 12.
self.assertEqual(ckpt['meta']['epoch'], 0)
self.assertEqual(ckpt['meta']['iter'], 12)
# self.assertEqual(ckpt['meta']['hook_msgs']['last_ckpt'], path)
assert isinstance(ckpt['optimizer'], dict)
assert isinstance(ckpt['param_schedulers'], list)
# 2.2 test `load_checkpoint`
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_checkpoint7'
# FIX(review): missing `Runner.from_cfg(cfg)` restored (see 2. above).
runner = Runner.from_cfg(cfg)
runner.load_checkpoint(path)
# `load_checkpoint` restores weights only; epoch/iter remain 0.
self.assertEqual(runner.epoch, 0)
self.assertEqual(runner.iter, 0)
self.assertTrue(runner._has_loaded)
# 2.3 test `resume`
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_checkpoint8'
# FIX(review): missing `Runner.from_cfg(cfg)` restored (see 2. above).
runner = Runner.from_cfg(cfg)
runner.resume(path)
self.assertEqual(runner.epoch, 0)
self.assertEqual(runner.iter, 12)
self.assertTrue(runner._has_loaded)
self.assertIsInstance(runner.optimizer, SGD)
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
# NOTE(review): original source lines 1141-1165 were lost in extraction — only
# their line numbers survived here. The lost span sat between the "2.3 resume"
# checks above and the "2.4 auto resume" section below; confirm its contents
# against the upstream file.
# 2.4 test auto resume
# With cfg.resume=True and no explicit load_from, `load_or_resume` resumes
# from the latest checkpoint in the work dir (iter_12 here): iter is restored
# to 12 while epoch stays 0 for the iter-based loop.
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_checkpoint9'
cfg.resume = True
runner = Runner.from_cfg(cfg)
runner.load_or_resume()
self.assertEqual(runner.epoch, 0)
self.assertEqual(runner.iter, 12)
self.assertTrue(runner._has_loaded)
self.assertIsInstance(runner.optimizer, SGD)
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
# 2.5 test resume from a specified checkpoint
# When both resume=True and load_from are set, the specified checkpoint
# (iter_3) wins over the latest one, and its progress is restored.
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_checkpoint10'
cfg.resume = True
cfg.load_from = osp.join(self.temp_dir, 'iter_3.pth')
runner = Runner.from_cfg(cfg)
runner.load_or_resume()
self.assertEqual(runner.epoch, 0)
self.assertEqual(runner.iter, 3)
self.assertTrue(runner._has_loaded)
self.assertIsInstance(runner.optimizer, SGD)
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)