Skip to content
Snippets Groups Projects
test_runner.py 43.8 KiB
Newer Older
RangiLyu's avatar
RangiLyu committed

        @HOOKS.register_module()
        class TestWarmupHook(Hook):
            """test custom train loop."""

            def before_warmup_iter(self, runner, data_batch=None):
RangiLyu's avatar
RangiLyu committed
                before_warmup_iter_results.append('before')

            def after_warmup_iter(self, runner, data_batch=None, outputs=None):
RangiLyu's avatar
RangiLyu committed
                after_warmup_iter_results.append('after')

        self.iter_based_cfg.train_cfg = dict(
            type='CustomTrainLoop2',
            max_iters=10,
RangiLyu's avatar
RangiLyu committed
            warmup_loader=dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='InfiniteSampler', shuffle=True),
RangiLyu's avatar
RangiLyu committed
                batch_size=1,
                num_workers=0),
            max_warmup_iters=5)
        self.iter_based_cfg.custom_hooks = [
            dict(type='TestWarmupHook', priority=50)
        ]
        self.iter_based_cfg.experiment_name = 'test_custom_loop'
        runner = Runner.from_cfg(self.iter_based_cfg)
RangiLyu's avatar
RangiLyu committed
        runner.train()

        self.assertIsInstance(runner.train_loop, CustomTrainLoop2)

        # test custom hook triggered as expected
RangiLyu's avatar
RangiLyu committed
        self.assertEqual(len(before_warmup_iter_results), 5)
        self.assertEqual(len(after_warmup_iter_results), 5)
        for before, after in zip(before_warmup_iter_results,
                                 after_warmup_iter_results):
            self.assertEqual(before, 'before')
            self.assertEqual(after, 'after')

    def test_checkpoint(self):
        """Test saving, loading and resuming checkpoints.

        Section 1 exercises an epoch-based config and section 2 an
        iter-based config. Each section checks:

        * the checkpoint files written during ``runner.train()`` and the
          ``meta`` / ``optimizer`` / ``param_schedulers`` entries they hold;
        * ``load_checkpoint``, which must leave ``epoch``/``iter`` at 0;
        * ``resume``, which must restore ``epoch``/``iter`` plus the
          optimizer and parameter schedulers;
        * ``load_or_resume`` with ``cfg.resume = True``, both without and
          with an explicit ``cfg.load_from``.
        """
        # 1. test epoch based
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_checkpoint1'
        runner = Runner.from_cfg(cfg)
        runner.train()

        # 1.1 test `save_checkpoint` which is called by `CheckpointHook`
        path = osp.join(self.temp_dir, 'epoch_3.pth')
        self.assertTrue(osp.exists(path))
        self.assertTrue(osp.exists(osp.join(self.temp_dir, 'latest.pth')))
        # The run ends at epoch 3 (see meta below): no later checkpoint.
        self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_4.pth')))

        ckpt = torch.load(path)
        self.assertEqual(ckpt['meta']['epoch'], 3)
        self.assertEqual(ckpt['meta']['iter'], 12)
        # self.assertEqual(ckpt['meta']['hook_msgs']['last_ckpt'], path)
        self.assertIsInstance(ckpt['optimizer'], dict)
        self.assertIsInstance(ckpt['param_schedulers'], list)

        # 1.2 test `load_checkpoint`: weights only, counters stay at 0
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_checkpoint2'
        runner = Runner.from_cfg(cfg)
        runner.load_checkpoint(path)
        self.assertEqual(runner.epoch, 0)
        self.assertEqual(runner.iter, 0)
        self.assertTrue(runner._has_loaded)

        # 1.3 test `resume`: counters, optimizer and schedulers restored
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_checkpoint3'
        runner = Runner.from_cfg(cfg)
        runner.resume(path)
        self.assertEqual(runner.epoch, 3)
        self.assertEqual(runner.iter, 12)
        self.assertTrue(runner._has_loaded)
        self.assertIsInstance(runner.optimizer, SGD)
        self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)

        # 1.4 test auto resume (no `load_from`; presumably resumes from the
        # latest checkpoint found in `self.temp_dir` -- the expected counters
        # match epoch_3.pth).
        # NOTE(review): sections 1.4 and 1.5 build from `iter_based_cfg`
        # although they live in the epoch-based section; the expected
        # counters come from the checkpoint. Confirm this is intentional.
        cfg = copy.deepcopy(self.iter_based_cfg)
        cfg.experiment_name = 'test_checkpoint4'
        cfg.resume = True
        runner = Runner.from_cfg(cfg)
        runner.load_or_resume()
        self.assertEqual(runner.epoch, 3)
        self.assertEqual(runner.iter, 12)
        self.assertTrue(runner._has_loaded)
        self.assertIsInstance(runner.optimizer, SGD)
        self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)

        # 1.5 test resume from a specified checkpoint
        cfg = copy.deepcopy(self.iter_based_cfg)
        cfg.experiment_name = 'test_checkpoint5'
        cfg.resume = True
        cfg.load_from = osp.join(self.temp_dir, 'epoch_1.pth')
        runner = Runner.from_cfg(cfg)
        runner.load_or_resume()
        self.assertEqual(runner.epoch, 1)
        self.assertEqual(runner.iter, 4)
        self.assertTrue(runner._has_loaded)
        self.assertIsInstance(runner.optimizer, SGD)
        self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)

        # 2. test iter based
        cfg = copy.deepcopy(self.iter_based_cfg)
        cfg.experiment_name = 'test_checkpoint6'
        runner = Runner.from_cfg(cfg)
        runner.train()

        # 2.1 test `save_checkpoint` which is called by `CheckpointHook`
        path = osp.join(self.temp_dir, 'iter_12.pth')
        self.assertTrue(osp.exists(path))
        self.assertTrue(osp.exists(osp.join(self.temp_dir, 'latest.pth')))
        # The run ends at iter 12 (see meta below): no later checkpoint.
        # Iter-based runs name checkpoints 'iter_N.pth', so check that name;
        # an 'epoch_13.pth' check could never fail here.
        self.assertFalse(osp.exists(osp.join(self.temp_dir, 'iter_13.pth')))

        ckpt = torch.load(path)
        self.assertEqual(ckpt['meta']['epoch'], 0)
        self.assertEqual(ckpt['meta']['iter'], 12)
        # self.assertEqual(ckpt['meta']['hook_msgs']['last_ckpt'], path)
        self.assertIsInstance(ckpt['optimizer'], dict)
        self.assertIsInstance(ckpt['param_schedulers'], list)

        # 2.2 test `load_checkpoint`: weights only, counters stay at 0
        cfg = copy.deepcopy(self.iter_based_cfg)
        cfg.experiment_name = 'test_checkpoint7'
        runner = Runner.from_cfg(cfg)
        runner.load_checkpoint(path)
        self.assertEqual(runner.epoch, 0)
        self.assertEqual(runner.iter, 0)
        self.assertTrue(runner._has_loaded)

        # 2.3 test `resume`: counters, optimizer and schedulers restored
        cfg = copy.deepcopy(self.iter_based_cfg)
        cfg.experiment_name = 'test_checkpoint8'
        runner = Runner.from_cfg(cfg)
        runner.resume(path)
        self.assertEqual(runner.epoch, 0)
        self.assertEqual(runner.iter, 12)
        self.assertTrue(runner._has_loaded)
        self.assertIsInstance(runner.optimizer, SGD)
        self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)

        # 2.4 test auto resume (no `load_from`; expected counters match
        # iter_12.pth, presumably the latest checkpoint in `self.temp_dir`)
        cfg = copy.deepcopy(self.iter_based_cfg)
        cfg.experiment_name = 'test_checkpoint9'
        cfg.resume = True
        runner = Runner.from_cfg(cfg)
        runner.load_or_resume()
        self.assertEqual(runner.epoch, 0)
        self.assertEqual(runner.iter, 12)
        self.assertTrue(runner._has_loaded)
        self.assertIsInstance(runner.optimizer, SGD)
        self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)

        # 2.5 test resume from a specified checkpoint
        cfg = copy.deepcopy(self.iter_based_cfg)
        cfg.experiment_name = 'test_checkpoint10'
        cfg.resume = True
        cfg.load_from = osp.join(self.temp_dir, 'iter_3.pth')
        runner = Runner.from_cfg(cfg)
        runner.load_or_resume()
        self.assertEqual(runner.epoch, 0)
        self.assertEqual(runner.iter, 3)
        self.assertTrue(runner._has_loaded)
        self.assertIsInstance(runner.optimizer, SGD)
        self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)