Skip to content
Snippets Groups Projects
Unverified Commit 6e4bcc99 authored by RangiLyu's avatar RangiLyu Committed by GitHub
Browse files

[Fix] Fix resume from checkpoint. (#174)

parent 798eab48
No related branches found
No related tags found
No related merge requests found
......@@ -1084,7 +1084,13 @@ class Runner:
# decide to load from checkpoint or resume from checkpoint
resume_from = None
if self._resume and self._load_from is None:
# auto resume from the latest checkpoint
resume_from = find_latest_checkpoint(self.work_dir)
self.logger.info(
f'Auto resumed from the latest checkpoint {resume_from}.')
elif self._resume and self._load_from is not None:
# resume from the specified checkpoint
resume_from = self._load_from
if resume_from is not None:
self.resume(resume_from)
......
......@@ -1075,9 +1075,34 @@ class TestRunner(TestCase):
self.assertIsInstance(runner.optimizer, SGD)
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
# 2. test iter based
# 1.4 test auto resume
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_checkpoint4'
cfg.resume = True
runner = Runner.from_cfg(cfg)
runner.load_or_resume()
self.assertEqual(runner.epoch, 3)
self.assertEqual(runner.iter, 12)
self.assertTrue(runner._has_loaded)
self.assertIsInstance(runner.optimizer, SGD)
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
# 1.5 test resume from a specified checkpoint
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_checkpoint5'
cfg.resume = True
cfg.load_from = osp.join(self.temp_dir, 'epoch_1.pth')
runner = Runner.from_cfg(cfg)
runner.load_or_resume()
self.assertEqual(runner.epoch, 1)
self.assertEqual(runner.iter, 4)
self.assertTrue(runner._has_loaded)
self.assertIsInstance(runner.optimizer, SGD)
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
# 2. test iter based
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_checkpoint6'
runner = Runner.from_cfg(cfg)
runner.train()
......@@ -1096,7 +1121,7 @@ class TestRunner(TestCase):
# 2.2 test `load_checkpoint`
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_checkpoint5'
cfg.experiment_name = 'test_checkpoint7'
runner = Runner.from_cfg(cfg)
runner.load_checkpoint(path)
self.assertEqual(runner.epoch, 0)
......@@ -1105,7 +1130,7 @@ class TestRunner(TestCase):
# 2.3 test `resume`
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_checkpoint6'
cfg.experiment_name = 'test_checkpoint8'
runner = Runner.from_cfg(cfg)
runner.resume(path)
self.assertEqual(runner.epoch, 0)
......@@ -1113,3 +1138,28 @@ class TestRunner(TestCase):
self.assertTrue(runner._has_loaded)
self.assertIsInstance(runner.optimizer, SGD)
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
# 2.4 test auto resume
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_checkpoint9'
cfg.resume = True
runner = Runner.from_cfg(cfg)
runner.load_or_resume()
self.assertEqual(runner.epoch, 0)
self.assertEqual(runner.iter, 12)
self.assertTrue(runner._has_loaded)
self.assertIsInstance(runner.optimizer, SGD)
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
# 2.5 test resume from a specified checkpoint
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_checkpoint10'
cfg.resume = True
cfg.load_from = osp.join(self.temp_dir, 'iter_3.pth')
runner = Runner.from_cfg(cfg)
runner.load_or_resume()
self.assertEqual(runner.epoch, 0)
self.assertEqual(runner.iter, 3)
self.assertTrue(runner._has_loaded)
self.assertIsInstance(runner.optimizer, SGD)
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment