        # 4. test multiple optimizers and multiple parameter schedulers
        cfg = dict(
            key1=dict(type='MultiStepLR', milestones=[1, 2]),
            key2=[
                dict(type='MultiStepLR', milestones=[1, 2]),
                dict(type='StepLR', step_size=1)
            ])
        param_schedulers = runner.build_param_scheduler(cfg)
        self.assertIsInstance(param_schedulers, dict)
        self.assertEqual(len(param_schedulers), 2)
        self.assertEqual(len(param_schedulers['key1']), 1)
        self.assertEqual(len(param_schedulers['key2']), 2)

        # 5. test converting epoch-based scheduler to iter-based
        runner.optim_wrapper = runner.build_optim_wrapper(
            dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)))

        # 5.1 train loop should be built before converting scheduler
        cfg = dict(
            type='MultiStepLR', milestones=[1, 2], convert_to_iter_based=True)
        with self.assertRaises(AssertionError):
            # converting to iter-based requires the train loop to be built
            runner.build_param_scheduler(cfg)

        # 5.2 convert epoch-based to iter-based scheduler
        cfg = dict(
            type='MultiStepLR',
            milestones=[1, 2],
            begin=1,
            end=7,
            convert_to_iter_based=True)
        runner._train_loop = runner.build_train_loop(runner.train_loop)
        param_schedulers = runner.build_param_scheduler(cfg)
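        # the train dataloader yields 4 iterations per epoch, so begin=1 epoch
        # -> iteration 4 and end=7 epochs -> iteration 28 after conversion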
        self.assertFalse(param_schedulers[0].by_epoch)
        self.assertEqual(param_schedulers[0].begin, 4)
        self.assertEqual(param_schedulers[0].end, 28)

        # 6. test set default end of schedulers
        cfg = dict(type='MultiStepLR', milestones=[1, 2], begin=1)
        param_schedulers = runner.build_param_scheduler(cfg)
        self.assertTrue(param_schedulers[0].by_epoch)
        self.assertEqual(param_schedulers[0].begin, 1)
        # runner.max_epochs = 3
        self.assertEqual(param_schedulers[0].end, 3)

        cfg = dict(
            type='MultiStepLR',
            milestones=[1, 2],
            begin=1,
            convert_to_iter_based=True)
        param_schedulers = runner.build_param_scheduler(cfg)
        self.assertFalse(param_schedulers[0].by_epoch)
        self.assertEqual(param_schedulers[0].begin, 4)
        # runner.max_iters = 3*4
        self.assertEqual(param_schedulers[0].end, 12)

    def test_build_evaluator(self):
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_build_evaluator'
        runner = Runner.from_cfg(cfg)

        # input is an Evaluator object
        evaluator = Evaluator(ToyMetric1())
        self.assertEqual(id(runner.build_evaluator(evaluator)), id(evaluator))

        evaluator = Evaluator([ToyMetric1(), ToyMetric2()])
        self.assertEqual(id(runner.build_evaluator(evaluator)), id(evaluator))

        # input is a dict
        evaluator = dict(type='ToyMetric1')
        self.assertIsInstance(runner.build_evaluator(evaluator), Evaluator)

        # input is a list of dict
        evaluator = [dict(type='ToyMetric1'), dict(type='ToyMetric2')]
        self.assertIsInstance(runner.build_evaluator(evaluator), Evaluator)
        # test collect device
        evaluator = [
            dict(type='ToyMetric1', collect_device='cpu'),
            dict(type='ToyMetric2', collect_device='gpu')
        ]
        _evaluator = runner.build_evaluator(evaluator)
        self.assertEqual(_evaluator.metrics[0].collect_device, 'cpu')
        self.assertEqual(_evaluator.metrics[1].collect_device, 'gpu')
        # test building a custom evaluator
        evaluator = dict(
            type='ToyEvaluator',
            metrics=[
                dict(type='ToyMetric1', collect_device='cpu'),
                dict(type='ToyMetric2', collect_device='gpu')
            ])
        _evaluator = runner.build_evaluator(evaluator)
        self.assertIsInstance(_evaluator, ToyEvaluator)
        self.assertEqual(_evaluator.metrics[0].collect_device, 'cpu')
        self.assertEqual(_evaluator.metrics[1].collect_device, 'gpu')

    def test_build_dataloader(self):
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_build_dataloader'
        runner = Runner.from_cfg(cfg)

        cfg = dict(
            dataset=dict(type='ToyDataset'),
            sampler=dict(type='DefaultSampler', shuffle=True),
            batch_size=1,
            num_workers=0)
        seed = np.random.randint(2**31)
        dataloader = runner.build_dataloader(cfg, seed=seed)
        self.assertIsInstance(dataloader, DataLoader)
        self.assertIsInstance(dataloader.dataset, ToyDataset)
        self.assertIsInstance(dataloader.sampler, DefaultSampler)
        self.assertEqual(dataloader.sampler.seed, seed)
        # diff_rank_seed is True
        dataloader = runner.build_dataloader(
            cfg, seed=seed, diff_rank_seed=True)
        self.assertNotEqual(dataloader.sampler.seed, seed)
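        # with `diff_rank_seed` enabled, the sampler seed is generated
        # per rank instead of taken from the passed-in seed, so it differs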

    def test_build_train_loop(self):
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_build_train_loop'
        runner = Runner.from_cfg(cfg)

        # input should be a Loop object or dict
        with self.assertRaisesRegex(TypeError, 'should be'):
            runner.build_train_loop('invalid-type')

        # Only one of type or by_epoch can exist in cfg
        cfg = dict(type='EpochBasedTrainLoop', by_epoch=True, max_epochs=3)
        with self.assertRaisesRegex(RuntimeError, 'Only one'):
            runner.build_train_loop(cfg)

        # input is a dict and contains type key
        cfg = dict(type='EpochBasedTrainLoop', max_epochs=3)
        loop = runner.build_train_loop(cfg)
        self.assertIsInstance(loop, EpochBasedTrainLoop)

        cfg = dict(type='IterBasedTrainLoop', max_iters=3)
        loop = runner.build_train_loop(cfg)
        self.assertIsInstance(loop, IterBasedTrainLoop)

        # input is a dict and does not contain type key
        cfg = dict(by_epoch=True, max_epochs=3)
        loop = runner.build_train_loop(cfg)
        self.assertIsInstance(loop, EpochBasedTrainLoop)

        cfg = dict(by_epoch=False, max_iters=3)
        loop = runner.build_train_loop(cfg)
        self.assertIsInstance(loop, IterBasedTrainLoop)

        # input is a Loop object
        self.assertEqual(id(runner.build_train_loop(loop)), id(loop))

        # param_schedulers can be None
        cfg = dict(type='EpochBasedTrainLoop', max_epochs=3)
        runner.param_schedulers = None
        loop = runner.build_train_loop(cfg)
        self.assertIsInstance(loop, EpochBasedTrainLoop)

        # test custom training loop
        cfg = dict(type='CustomTrainLoop', max_epochs=3)
        loop = runner.build_train_loop(cfg)
        self.assertIsInstance(loop, CustomTrainLoop)

    def test_build_val_loop(self):
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_build_val_loop'
        runner = Runner.from_cfg(cfg)

        # input should be a Loop object or dict
        with self.assertRaisesRegex(TypeError, 'should be'):
            runner.build_val_loop('invalid-type')

        # input is a dict and contains type key
        cfg = dict(type='ValLoop')
        loop = runner.build_val_loop(cfg)
        self.assertIsInstance(loop, ValLoop)

        # input is a dict but does not contain type key
        cfg = dict()
        loop = runner.build_val_loop(cfg)
        self.assertIsInstance(loop, ValLoop)

        # input is a Loop object
        self.assertEqual(id(runner.build_val_loop(loop)), id(loop))

        # test custom validation loop
        cfg = dict(type='CustomValLoop')
        loop = runner.build_val_loop(cfg)
        self.assertIsInstance(loop, CustomValLoop)

    def test_build_test_loop(self):
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_build_test_loop'
        runner = Runner.from_cfg(cfg)

        # input should be a Loop object or dict
        with self.assertRaisesRegex(TypeError, 'should be'):
            runner.build_test_loop('invalid-type')

        # input is a dict and contains type key
        cfg = dict(type='TestLoop')
        loop = runner.build_test_loop(cfg)
        self.assertIsInstance(loop, TestLoop)

        # input is a dict but does not contain type key
        cfg = dict()
        loop = runner.build_test_loop(cfg)
        self.assertIsInstance(loop, TestLoop)

        # input is a Loop object
        self.assertEqual(id(runner.build_test_loop(loop)), id(loop))

        # test custom test loop
        cfg = dict(type='CustomTestLoop')
        loop = runner.build_test_loop(cfg)
        self.assertIsInstance(loop, CustomTestLoop)

    def test_build_log_processor(self):
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_build_log_processor'
        runner = Runner.from_cfg(cfg)

        # input should be a LogProcessor object or dict
        with self.assertRaisesRegex(TypeError, 'should be'):
            runner.build_log_processor('invalid-type')

        # input is a dict and contains type key
        cfg = dict(type='LogProcessor')
        log_processor = runner.build_log_processor(cfg)
        self.assertIsInstance(log_processor, LogProcessor)

        # input is a dict but does not contain type key
        cfg = dict()
        log_processor = runner.build_log_processor(cfg)
        self.assertIsInstance(log_processor, LogProcessor)

        # input is a LogProcessor object
        self.assertEqual(
            id(runner.build_log_processor(log_processor)), id(log_processor))

        # test custom log processor
        cfg = dict(type='CustomLogProcessor')
        log_processor = runner.build_log_processor(cfg)
        self.assertIsInstance(log_processor, CustomLogProcessor)

    def test_train(self):
        # 1. test `self.train_loop` is None
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_train1'
        cfg.pop('train_dataloader')
        cfg.pop('train_cfg')
        cfg.pop('optim_wrapper')
        cfg.pop('param_scheduler')
        runner = Runner.from_cfg(cfg)
        with self.assertRaisesRegex(RuntimeError, 'should not be None'):
            runner.train()
        # 2. test iter and epoch counter of EpochBasedTrainLoop and timing of
        # running ValLoop
        epoch_results = []
        epoch_targets = [i for i in range(3)]
        iter_results = []
        iter_targets = [i for i in range(4 * 3)]
        batch_idx_results = []
        batch_idx_targets = [i for i in range(4)] * 3  # train and val
        val_epoch_results = []
        val_epoch_targets = [i for i in range(2, 4)]
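        # expected counters: 3 epochs x 4 train iterations each; with
        # val_begin=2, validation runs after epochs 2 and 3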

        @HOOKS.register_module()
        class TestEpochHook(Hook):

            def before_train_epoch(self, runner):
                epoch_results.append(runner.epoch)

            def before_train_iter(self, runner, batch_idx, data_batch=None):
                iter_results.append(runner.iter)
                batch_idx_results.append(batch_idx)

            def before_val_epoch(self, runner):
                val_epoch_results.append(runner.epoch)

        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_train2'
        cfg.custom_hooks = [dict(type='TestEpochHook', priority=50)]
        cfg.train_cfg = dict(by_epoch=True, max_epochs=3, val_begin=2)
        runner = Runner.from_cfg(cfg)
        runner.train()
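        # 3 epochs x 4 iterations per epoch = 12 optimizer steps in total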
        self.assertEqual(runner.optim_wrapper._inner_count, 12)
        self.assertEqual(runner.optim_wrapper._max_counts, 12)
        assert isinstance(runner.train_loop, EpochBasedTrainLoop)

        for result, target, in zip(epoch_results, epoch_targets):
            self.assertEqual(result, target)
        for result, target, in zip(iter_results, iter_targets):
            self.assertEqual(result, target)
        for result, target, in zip(batch_idx_results, batch_idx_targets):
            self.assertEqual(result, target)
        for result, target, in zip(val_epoch_results, val_epoch_targets):
            self.assertEqual(result, target)
        # 3. test iter and epoch counter of IterBasedTrainLoop and timing of
        # running ValLoop
        epoch_results = []
        iter_results = []
        batch_idx_results = []
        val_iter_results = []
        val_batch_idx_results = []
        iter_targets = [i for i in range(12)]
        batch_idx_targets = [i for i in range(12)]
        val_iter_targets = [i for i in range(4, 12)]
        val_batch_idx_targets = [i for i in range(4)] * 2
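        # 12 training iterations in total; validation starts once 4 iterations
        # have run (val_begin=4) and repeats every 4 iterations, and each
        # validation pass iterates the 4-batch val dataloader (indices 0-3)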

        @HOOKS.register_module()
        class TestIterHook(Hook):

            def before_train_epoch(self, runner):
                epoch_results.append(runner.epoch)

            def before_train_iter(self, runner, batch_idx, data_batch=None):
                iter_results.append(runner.iter)
                batch_idx_results.append(batch_idx)

            def before_val_iter(self, runner, batch_idx, data_batch=None):
                val_epoch_results.append(runner.iter)
                val_batch_idx_results.append(batch_idx)

        cfg = copy.deepcopy(self.iter_based_cfg)
        cfg.experiment_name = 'test_train3'
        cfg.custom_hooks = [dict(type='TestIterHook', priority=50)]
        cfg.train_cfg = dict(
            by_epoch=False, max_iters=12, val_interval=4, val_begin=4)
        runner = Runner.from_cfg(cfg)
        runner.train()

        self.assertEqual(runner.optim_wrapper._inner_count, 12)
        self.assertEqual(runner.optim_wrapper._max_counts, 12)
        assert isinstance(runner.train_loop, IterBasedTrainLoop)

        self.assertEqual(len(epoch_results), 1)
        self.assertEqual(epoch_results[0], 0)
        self.assertEqual(runner.val_interval, 4)
        self.assertEqual(runner.val_begin, 4)
        for result, target, in zip(iter_results, iter_targets):
            self.assertEqual(result, target)
        for result, target, in zip(batch_idx_results, batch_idx_targets):
            self.assertEqual(result, target)
        for result, target, in zip(val_iter_results, val_iter_targets):
            self.assertEqual(result, target)
        for result, target, in zip(val_batch_idx_results,
                                   val_batch_idx_targets):
            self.assertEqual(result, target)

        # 4. test iter and epoch counter of IterBasedTrainLoop and timing of
        # running ValLoop without InfiniteSampler
        epoch_results = []
        iter_results = []
        batch_idx_results = []
        val_iter_results = []
        val_batch_idx_results = []
        iter_targets = [i for i in range(12)]
        batch_idx_targets = [i for i in range(12)]
        val_iter_targets = [i for i in range(4, 12)]
        val_batch_idx_targets = [i for i in range(4)] * 2

        cfg = copy.deepcopy(self.iter_based_cfg)
        cfg.experiment_name = 'test_train4'
        cfg.train_dataloader.sampler = dict(
            type='DefaultSampler', shuffle=True)
        cfg.custom_hooks = [dict(type='TestIterHook', priority=50)]
        cfg.train_cfg = dict(
            by_epoch=False, max_iters=12, val_interval=4, val_begin=4)
        runner = Runner.from_cfg(cfg)
        with self.assertWarnsRegex(
                Warning,
                'Reach the end of the dataloader, it will be restarted and '
                'continue to iterate.'):
            runner.train()

        assert isinstance(runner.train_loop, IterBasedTrainLoop)
        assert isinstance(runner.train_loop.dataloader_iterator,
                          _InfiniteDataloaderIterator)

        self.assertEqual(len(epoch_results), 1)
        self.assertEqual(epoch_results[0], 0)
        self.assertEqual(runner.val_interval, 4)
        self.assertEqual(runner.val_begin, 4)
        for result, target, in zip(iter_results, iter_targets):
            self.assertEqual(result, target)
        for result, target, in zip(batch_idx_results, batch_idx_targets):
            self.assertEqual(result, target)
        for result, target, in zip(val_iter_results, val_iter_targets):
            self.assertEqual(result, target)
        for result, target, in zip(val_batch_idx_results,
                                   val_batch_idx_targets):
            self.assertEqual(result, target)
        # 5. test dynamic interval in IterBasedTrainLoop
        max_iters = 12
        interval = 5
        dynamic_intervals = [(11, 2)]
        iter_results = []
        iter_targets = [5, 10, 12]
        val_interval_results = []
        val_interval_targets = [5] * 10 + [2] * 2
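        # with val_interval=5, validation runs at iterations 5 and 10; after
        # iteration 11 the interval switches to 2, so the final validation
        # lands on iteration 12 (hence iter_targets = [5, 10, 12])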

        @HOOKS.register_module()
        class TestIterDynamicIntervalHook(Hook):

            def before_val(self, runner):
                iter_results.append(runner.iter)

            def before_train_iter(self, runner, batch_idx, data_batch=None):
                val_interval_results.append(runner.train_loop.val_interval)

        cfg = copy.deepcopy(self.iter_based_cfg)
        cfg.experiment_name = 'test_train5'
        cfg.train_dataloader.sampler = dict(
            type='DefaultSampler', shuffle=True)
        cfg.custom_hooks = [
            dict(type='TestIterDynamicIntervalHook', priority=50)
        ]
        cfg.train_cfg = dict(
            by_epoch=False,
            max_iters=max_iters,
            val_interval=interval,
            dynamic_intervals=dynamic_intervals)
        runner = Runner.from_cfg(cfg)
        runner.train()
        for result, target, in zip(iter_results, iter_targets):
            self.assertEqual(result, target)
        for result, target, in zip(val_interval_results, val_interval_targets):
            self.assertEqual(result, target)

        # 6. test dynamic interval in EpochBasedTrainLoop
        max_epochs = 12
        interval = 5
        dynamic_intervals = [(11, 2)]
        epoch_results = []
        epoch_targets = [5, 10, 12]
        val_interval_results = []
        val_interval_targets = [5] * 10 + [2] * 2
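        # same schedule as above but counted in epochs: validation at epochs
        # 5, 10 and 12, with the interval switching from 5 to 2 after epoch 11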

        @HOOKS.register_module()
        class TestEpochDynamicIntervalHook(Hook):

            def before_val_epoch(self, runner):
                epoch_results.append(runner.epoch)

            def before_train_epoch(self, runner):
                val_interval_results.append(runner.train_loop.val_interval)

        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_train6'
        cfg.train_dataloader.sampler = dict(
            type='DefaultSampler', shuffle=True)
        cfg.custom_hooks = [
            dict(type='TestEpochDynamicIntervalHook', priority=50)
        ]
        cfg.train_cfg = dict(
            by_epoch=True,
            max_epochs=max_epochs,
            val_interval=interval,
            dynamic_intervals=dynamic_intervals)
        runner = Runner.from_cfg(cfg)
        runner.train()
        for result, target, in zip(epoch_results, epoch_targets):
            self.assertEqual(result, target)
        for result, target, in zip(val_interval_results, val_interval_targets):
            self.assertEqual(result, target)

        # 7. test init weights
        @MODELS.register_module()
        class ToyModel2(ToyModel):

            def __init__(self):
                super().__init__()
                self.initialized = False

            def init_weights(self):
                self.initialized = True

        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_train7'
        runner = Runner.from_cfg(cfg)
        model = ToyModel2()
        runner.model = model
        runner.train()
        self.assertTrue(model.initialized)

        # 8. Test train with multiple optimizers.
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_train8'
        cfg.param_scheduler = dict(type='MultiStepLR', milestones=[1, 2])
        cfg.optim_wrapper = dict(
            linear1=dict(
                type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)),
            linear2=dict(
                type='OptimWrapper', optimizer=dict(type='Adam', lr=0.02)),
            constructor='ToyMultipleOptimizerConstructor')
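        # `ToyMultipleOptimizerConstructor` builds an `OptimWrapperDict` keyed
        # by 'linear1' and 'linear2' (asserted in the checkpoint test below),
        # so a single scheduler cfg is applied per optimizer in the dict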
        cfg.model = dict(type='ToyGANModel')
        runner = runner.from_cfg(cfg)
        runner.train()

        # 8.1 Test train with multiple optimizers and a single scheduler.
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_train8.1.1'
        cfg.param_scheduler = dict(type='MultiStepLR', milestones=[1, 2])
        cfg.optim_wrapper = dict(
            linear1=dict(
                type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)),
            linear2=dict(
                type='OptimWrapper', optimizer=dict(type='Adam', lr=0.02)),
            constructor='ToyMultipleOptimizerConstructor')
        cfg.model = dict(type='ToyGANModel')
        runner = runner.from_cfg(cfg)
        runner.train()

        # Test a single scheduler wrapped in a list.
        cfg.experiment_name = 'test_train8.1.2'
        cfg.param_scheduler = [dict(type='MultiStepLR', milestones=[1, 2])]
        runner = runner.from_cfg(cfg)
        runner.train()

        # 8.2 Test train with multiple optimizers and multiple schedulers.
        cfg.experiment_name = 'test_train8.2.1'
        cfg.param_scheduler = dict(
            linear1=dict(type='MultiStepLR', milestones=[1, 2]),
            linear2=dict(type='MultiStepLR', milestones=[1, 2]),
        )
        runner = runner.from_cfg(cfg)
        runner.train()

        cfg.experiment_name = 'test_train8.2.2'
        cfg.param_scheduler = dict(
            linear1=[dict(type='MultiStepLR', milestones=[1, 2])],
            linear2=[dict(type='MultiStepLR', milestones=[1, 2])],
        )
        runner = runner.from_cfg(cfg)
        runner.train()

        # 9 Test training with a dataset without metainfo
        cfg.experiment_name = 'test_train9'
        cfg = copy.deepcopy(cfg)
        cfg.train_dataloader.dataset = dict(type='ToyDatasetNoMeta')
        runner = runner.from_cfg(cfg)
        runner.train()

    def test_val(self):
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_val1'
        cfg.pop('val_dataloader')
        cfg.pop('val_cfg')
        cfg.pop('val_evaluator')
        runner = Runner.from_cfg(cfg)
        with self.assertRaisesRegex(RuntimeError, 'should not be None'):
            runner.val()

        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_val2'
        runner = Runner.from_cfg(cfg)
        runner.val()

        # test run val without train and test components
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_individually_val'
        cfg.pop('train_dataloader')
        cfg.pop('train_cfg')
        cfg.pop('optim_wrapper')
        cfg.pop('param_scheduler')
        cfg.pop('test_dataloader')
        cfg.pop('test_cfg')
        cfg.pop('test_evaluator')
        runner = Runner.from_cfg(cfg)

        # Test default fp32 `autocast` context.
        predictions = []

        def get_outputs_callback(module, inputs, outputs):
            predictions.append(outputs)

        runner.model.register_forward_hook(get_outputs_callback)
        runner.val()
        self.assertEqual(predictions[0].dtype, torch.float32)
        predictions.clear()

        # Test fp16 `autocast` context.
        cfg.experiment_name = 'test_val3'
        cfg.val_cfg = dict(fp16=True)
        runner = Runner.from_cfg(cfg)
        runner.model.register_forward_hook(get_outputs_callback)
        if (digit_version(TORCH_VERSION) < digit_version('1.10.0')
                and not torch.cuda.is_available()):
            with self.assertRaisesRegex(RuntimeError, 'If pytorch versions'):
                runner.val()
        else:
            runner.val()
            self.assertIn(predictions[0].dtype,
                          (torch.float16, torch.bfloat16))

    def test_test(self):
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_test1'
        cfg.pop('test_dataloader')
        cfg.pop('test_cfg')
        cfg.pop('test_evaluator')
        runner = Runner.from_cfg(cfg)
        with self.assertRaisesRegex(RuntimeError, 'should not be None'):
            runner.test()

        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_test2'
        runner = Runner.from_cfg(cfg)
        runner.test()
        # Test run test without building train loop.
        self.assertIsInstance(runner._train_loop, dict)
        # test run test without train and test components
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_individually_test'
        cfg.pop('train_dataloader')
        cfg.pop('train_cfg')
        cfg.pop('optim_wrapper')
        cfg.pop('param_scheduler')
        cfg.pop('val_dataloader')
        cfg.pop('val_cfg')
        cfg.pop('val_evaluator')
        runner = Runner.from_cfg(cfg)

        # Test default fp32 `autocast` context.
        predictions = []

        def get_outputs_callback(module, inputs, outputs):
            predictions.append(outputs)

        runner.model.register_forward_hook(get_outputs_callback)
        runner.test()
        self.assertEqual(predictions[0].dtype, torch.float32)
        predictions.clear()

        # Test fp16 `autocast` context.
        cfg.experiment_name = 'test_test3'
        cfg.test_cfg = dict(fp16=True)
        runner = Runner.from_cfg(cfg)
        runner.model.register_forward_hook(get_outputs_callback)
        if (digit_version(TORCH_VERSION) < digit_version('1.10.0')
                and not torch.cuda.is_available()):
            with self.assertRaisesRegex(RuntimeError, 'If pytorch versions'):
                runner.test()
        else:
            runner.test()
            self.assertIn(predictions[0].dtype,
                          (torch.float16, torch.bfloat16))

    def test_register_hook(self):
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_register_hook'
        runner = Runner.from_cfg(cfg)
        runner._hooks = []

        # 1. test `hook` parameter
        # 1.1 `hook` should be either a Hook object or dict
        with self.assertRaisesRegex(
                TypeError, 'hook should be an instance of Hook or dict'):
            runner.register_hook(['string'])

        # 1.2 `hook` is a dict
        timer_cfg = dict(type='IterTimerHook')
        runner.register_hook(timer_cfg)
        self.assertEqual(len(runner._hooks), 1)
        self.assertTrue(isinstance(runner._hooks[0], IterTimerHook))
        # default priority of `IterTimerHook` is 'NORMAL'
        self.assertEqual(
            get_priority(runner._hooks[0].priority), get_priority('NORMAL'))

        runner._hooks = []
        # 1.2.1 `hook` is a dict and contains `priority` field
        # set the priority of `IterTimerHook` as 'BELOW_NORMAL'
        timer_cfg = dict(type='IterTimerHook', priority='BELOW_NORMAL')
        runner.register_hook(timer_cfg)
        self.assertEqual(len(runner._hooks), 1)
        self.assertTrue(isinstance(runner._hooks[0], IterTimerHook))
        self.assertEqual(
            get_priority(runner._hooks[0].priority),
            get_priority('BELOW_NORMAL'))

        # 1.3 `hook` is a hook object
        runtime_info_hook = RuntimeInfoHook()
        runner.register_hook(runtime_info_hook)
        self.assertEqual(len(runner._hooks), 2)
        # The priority of `runtime_info_hook` is `HIGH` which is greater than
        # `IterTimerHook`, so the first item of `_hooks` should be
        # `runtime_info_hook`
        self.assertTrue(isinstance(runner._hooks[0], RuntimeInfoHook))
        self.assertEqual(
            get_priority(runner._hooks[0].priority), get_priority('VERY_HIGH'))

        # 2. test `priority` parameter
        # if the `priority` argument is not None, it overrides the priority
        # specified in the hook config
        param_scheduler_cfg = dict(type='ParamSchedulerHook', priority='LOW')
        runner.register_hook(param_scheduler_cfg, priority='VERY_LOW')
        self.assertEqual(len(runner._hooks), 3)
        self.assertTrue(isinstance(runner._hooks[2], ParamSchedulerHook))
        self.assertEqual(
            get_priority(runner._hooks[2].priority), get_priority('VERY_LOW'))

        # `priority` is Priority
        logger_cfg = dict(type='LoggerHook', priority='BELOW_NORMAL')
        runner.register_hook(logger_cfg, priority=Priority.VERY_LOW)
        self.assertEqual(len(runner._hooks), 4)
        self.assertTrue(isinstance(runner._hooks[3], LoggerHook))
        self.assertEqual(
            get_priority(runner._hooks[3].priority), get_priority('VERY_LOW'))

    def test_default_hooks(self):
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_default_hooks'
        runner = Runner.from_cfg(cfg)
        runner._hooks = []

        runner.register_default_hooks()
        self.assertEqual(len(runner._hooks), 6)
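        # the six default hooks, ordered by priority, are assumed to be:
        # RuntimeInfoHook, IterTimerHook, DistSamplerSeedHook, LoggerHook,
        # ParamSchedulerHook, CheckpointHook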
        # the third registered hook should be `DistSamplerSeedHook`
        self.assertTrue(isinstance(runner._hooks[2], DistSamplerSeedHook))
        # the fifth registered hook should be `ParamSchedulerHook`
        self.assertTrue(isinstance(runner._hooks[4], ParamSchedulerHook))

        runner._hooks = []
        # remove `ParamSchedulerHook` from default hooks
        runner.register_default_hooks(hooks=dict(param_scheduler=None))
        self.assertEqual(len(runner._hooks), 5)
        # `ParamSchedulerHook` was popped so the fifth is `CheckpointHook`
        self.assertTrue(isinstance(runner._hooks[4], CheckpointHook))

        # add a new default hook
        runner._hooks = []
        runner.register_default_hooks(hooks=dict(ToyHook=dict(type='ToyHook')))
        self.assertEqual(len(runner._hooks), 7)
        self.assertTrue(isinstance(runner._hooks[6], ToyHook))

    def test_custom_hooks(self):
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_custom_hooks'
        runner = Runner.from_cfg(cfg)

        self.assertEqual(len(runner._hooks), 6)
        custom_hooks = [dict(type='ToyHook')]
        runner.register_custom_hooks(custom_hooks)
        self.assertEqual(len(runner._hooks), 7)
        self.assertTrue(isinstance(runner._hooks[6], ToyHook))

    def test_register_hooks(self):
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_register_hooks'
        runner = Runner.from_cfg(cfg)

        runner._hooks = []
        custom_hooks = [dict(type='ToyHook')]
        runner.register_hooks(custom_hooks=custom_hooks)
        # six default hooks + custom hook (ToyHook)
        self.assertEqual(len(runner._hooks), 7)
        self.assertTrue(isinstance(runner._hooks[6], ToyHook))

    def test_custom_loop(self):
        # test custom loop with additional hook
        @LOOPS.register_module()
        class CustomTrainLoop2(IterBasedTrainLoop):
            """Custom train loop with additional warmup stage."""
            def __init__(self, runner, dataloader, max_iters, warmup_loader,
                         max_warmup_iters):
                super().__init__(
                    runner=runner, dataloader=dataloader, max_iters=max_iters)
                self.warmup_loader = self.runner.build_dataloader(
                    warmup_loader)
                self.max_warmup_iters = max_warmup_iters

            def run(self):
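                # phase 1: run `max_warmup_iters` iterations over the warmup
                # dataloader, then fall through to the regular iter-based loop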
                self.runner.call_hook('before_train')
                self.runner.cur_dataloader = self.warmup_loader
                for idx, data_batch in enumerate(self.warmup_loader, 1):
                    self.warmup_iter(data_batch)
                    if idx == self.max_warmup_iters:
                        break

                self.runner.cur_dataloader = self.warmup_loader
                self.runner.call_hook('before_train_epoch')
                while self.runner.iter < self._max_iters:
                    data_batch = next(self.dataloader_iterator)
                    self.run_iter(data_batch)
                self.runner.call_hook('after_train_epoch')

                self.runner.call_hook('after_train')

            def warmup_iter(self, data_batch):
                self.runner.call_hook(
                    'before_warmup_iter', data_batch=data_batch)
                train_logs = self.runner.model.train_step(
                    data_batch, self.runner.optim_wrapper)
                self.runner.message_hub.update_info('train_logs', train_logs)
                self.runner.call_hook(
                    'after_warmup_iter', data_batch=data_batch)

        before_warmup_iter_results = []
        after_warmup_iter_results = []

        @HOOKS.register_module()
        class TestWarmupHook(Hook):
            """test custom train loop."""

            def before_warmup_iter(self, runner, data_batch=None):
                before_warmup_iter_results.append('before')

            def after_warmup_iter(self, runner, data_batch=None, outputs=None):
                after_warmup_iter_results.append('after')

        self.iter_based_cfg.train_cfg = dict(
            type='CustomTrainLoop2',
            max_iters=10,
            warmup_loader=dict(
                dataset=dict(type='ToyDataset'),
                sampler=dict(type='InfiniteSampler', shuffle=True),
                batch_size=1,
                num_workers=0),
            max_warmup_iters=5)
        self.iter_based_cfg.custom_hooks = [
            dict(type='TestWarmupHook', priority=50)
        ]
        self.iter_based_cfg.experiment_name = 'test_custom_loop'
        runner = Runner.from_cfg(self.iter_based_cfg)
        runner.train()

        self.assertIsInstance(runner.train_loop, CustomTrainLoop2)

        # test custom hook triggered as expected
        self.assertEqual(len(before_warmup_iter_results), 5)
        self.assertEqual(len(after_warmup_iter_results), 5)
        for before, after in zip(before_warmup_iter_results,
                                 after_warmup_iter_results):
            self.assertEqual(before, 'before')
            self.assertEqual(after, 'after')

    def test_checkpoint(self):
        # 1. test epoch based
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_checkpoint1'
        runner = Runner.from_cfg(cfg)
        runner.train()

        # 1.1 test `save_checkpoint` which is called by `CheckpointHook`
        path = osp.join(self.temp_dir, 'epoch_3.pth')
        self.assertTrue(osp.exists(path))
        self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_4.pth')))

        ckpt = torch.load(path)
        self.assertEqual(ckpt['meta']['epoch'], 3)
        self.assertEqual(ckpt['meta']['iter'], 12)
        self.assertEqual(ckpt['meta']['dataset_meta'],
                         runner.train_dataloader.dataset.metainfo)
        self.assertEqual(ckpt['meta']['experiment_name'],
                         runner.experiment_name)
        self.assertEqual(ckpt['meta']['seed'], runner.seed)
        assert isinstance(ckpt['optimizer'], dict)
        assert isinstance(ckpt['param_schedulers'], list)
        self.assertIsInstance(ckpt['message_hub'], dict)
        message_hub = MessageHub.get_instance('test_ckpt')
        message_hub.load_state_dict(ckpt['message_hub'])
        self.assertEqual(message_hub.get_info('epoch'), 2)
        self.assertEqual(message_hub.get_info('iter'), 11)

        # 1.2 test `load_checkpoint`
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_checkpoint2'
        cfg.optim_wrapper = dict(type='SGD', lr=0.2)
        cfg.param_scheduler = dict(type='MultiStepLR', milestones=[1, 2, 3])
        runner = Runner.from_cfg(cfg)
        runner.load_checkpoint(path)
        self.assertEqual(runner.epoch, 0)
        self.assertEqual(runner.iter, 0)
        self.assertTrue(runner._has_loaded)
        # load checkpoint will not initialize optimizer and param_schedulers
        # objects
        self.assertIsInstance(runner.optim_wrapper, dict)
        self.assertIsInstance(runner.param_schedulers, dict)
        # 1.3.1 test `resume`
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_checkpoint3'
        cfg.optim_wrapper = dict(
            type='OptimWrapper', optimizer=dict(type='SGD', lr=0.2))
        cfg.param_scheduler = dict(type='MultiStepLR', milestones=[1, 2, 3])
        runner = Runner.from_cfg(cfg)
        runner.resume(path)
        self.assertEqual(runner.epoch, 3)
        self.assertEqual(runner.iter, 12)
        self.assertTrue(runner._has_loaded)
        self.assertIsInstance(runner.optim_wrapper.optimizer, SGD)
        self.assertEqual(runner.optim_wrapper.param_groups[0]['lr'], 0.0001)
        self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
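        # MultiStepLR stores milestones as a Counter, so [1, 2] round-trips
        # through the checkpoint as {1: 1, 2: 1}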
        self.assertEqual(runner.param_schedulers[0].milestones, {1: 1, 2: 1})
        self.assertIsInstance(runner.message_hub, MessageHub)
        self.assertEqual(runner.message_hub.get_info('epoch'), 2)
        self.assertEqual(runner.message_hub.get_info('iter'), 11)
        self.assertEqual(MessageHub.get_current_instance().get_info('epoch'),
                         2)
        self.assertEqual(MessageHub.get_current_instance().get_info('iter'),
                         11)
        # 1.3.2 test resume with unmatched dataset_meta
        ckpt_modified = copy.deepcopy(ckpt)
        ckpt_modified['meta']['dataset_meta'] = {'CLASSES': ['cat', 'dog']}
        # ckpt_modified['meta']['seed'] = 123
        path_modified = osp.join(self.temp_dir, 'modified.pth')
        torch.save(ckpt_modified, path_modified)
        with self.assertWarnsRegex(
                Warning, 'The dataset metainfo from the resumed checkpoint is '
                'different from the current training dataset, please '
                'check the correctness of the checkpoint or the training '
                'dataset.'):
            runner.resume(path_modified)

        # 1.3.3 test resume with unmatched seed
        ckpt_modified = copy.deepcopy(ckpt)
        ckpt_modified['meta']['seed'] = 123
        path_modified = osp.join(self.temp_dir, 'modified.pth')
        torch.save(ckpt_modified, path_modified)
        with self.assertWarnsRegex(
                Warning, 'The value of random seed in the checkpoint'):
            runner.resume(path_modified)

        # 1.3.4 test resume with no seed and dataset meta
        ckpt_modified = copy.deepcopy(ckpt)
        ckpt_modified['meta'].pop('seed')
        ckpt_modified['meta'].pop('dataset_meta')
        path_modified = osp.join(self.temp_dir, 'modified.pth')
        torch.save(ckpt_modified, path_modified)
        runner.resume(path_modified)

        # 1.4 test auto resume
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_checkpoint4'
        cfg.resume = True
        runner = Runner.from_cfg(cfg)
        runner.load_or_resume()
        self.assertEqual(runner.epoch, 3)
        self.assertEqual(runner.iter, 12)
        self.assertTrue(runner._has_loaded)
        self.assertIsInstance(runner.optim_wrapper.optimizer, SGD)
        self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)

        # 1.5 test resume from a specified checkpoint
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_checkpoint5'
        cfg.resume = True
        cfg.load_from = osp.join(self.temp_dir, 'epoch_1.pth')
        runner = Runner.from_cfg(cfg)
        runner.load_or_resume()
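        # resuming from the epoch-1 checkpoint restores one epoch of progress
        # (4 iterations)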
        self.assertEqual(runner.epoch, 1)
        self.assertEqual(runner.iter, 4)
        self.assertTrue(runner._has_loaded)
        self.assertIsInstance(runner.optim_wrapper.optimizer, SGD)
        self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)

        # 1.6 multiple optimizers
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_checkpoint6'
        cfg.optim_wrapper = dict(
            linear1=dict(
                type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)),
            linear2=dict(
                type='OptimWrapper', optimizer=dict(type='Adam', lr=0.02)),
            constructor='ToyMultipleOptimizerConstructor')
        cfg.model = dict(type='ToyGANModel')
        # disable OptimizerHook because it only works with one optimizer
        runner = Runner.from_cfg(cfg)
        runner.train()
        path = osp.join(self.temp_dir, 'epoch_3.pth')
        self.assertTrue(osp.exists(path))
        self.assertIsInstance(runner.optim_wrapper['linear1'].optimizer, SGD)
        self.assertEqual(runner.optim_wrapper['linear1'].param_groups[0]['lr'],
                         0.0001)
        self.assertIsInstance(runner.optim_wrapper['linear2'].optimizer, Adam)
        self.assertEqual(runner.optim_wrapper['linear2'].param_groups[0]['lr'],
                         0.0002)

        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_checkpoint7'
        cfg.optim_wrapper = dict(
            linear1=dict(
                type='OptimWrapper', optimizer=dict(type='SGD', lr=0.2)),
            linear2=dict(
                type='OptimWrapper', optimizer=dict(type='Adam', lr=0.03)),
            constructor='ToyMultipleOptimizerConstructor')
        cfg.model = dict(type='ToyGANModel')
        cfg.param_scheduler = dict(type='MultiStepLR', milestones=[1, 2, 3])
        runner = Runner.from_cfg(cfg)
        runner.resume(path)
        self.assertIsInstance(runner.optim_wrapper, OptimWrapperDict)
        self.assertIsInstance(runner.optim_wrapper['linear1'].optimizer, SGD)