diff --git a/mmengine/hooks/runtime_info_hook.py b/mmengine/hooks/runtime_info_hook.py index 091ced5075494798a258850952abbfe8d6946c52..6098d540de2bd24030a738573e25c625756004be 100644 --- a/mmengine/hooks/runtime_info_hook.py +++ b/mmengine/hooks/runtime_info_hook.py @@ -18,17 +18,12 @@ class RuntimeInfoHook(Hook): priority = 'VERY_HIGH' - def before_run(self, runner) -> None: - """Initialize runtime information.""" - runner.message_hub.update_info('epoch', runner.epoch) - runner.message_hub.update_info('iter', runner.iter) - runner.message_hub.update_info('max_epochs', runner.max_epochs) - runner.message_hub.update_info('max_iters', runner.max_iters) - def before_train(self, runner) -> None: """Update resumed training state.""" runner.message_hub.update_info('epoch', runner.epoch) runner.message_hub.update_info('iter', runner.iter) + runner.message_hub.update_info('max_epochs', runner.max_epochs) + runner.message_hub.update_info('max_iters', runner.max_iters) def before_train_epoch(self, runner) -> None: """Update current epoch information before every epoch.""" diff --git a/tests/test_hook/test_runtime_info_hook.py b/tests/test_hook/test_runtime_info_hook.py index b57e26eb529de73a701e42dda819bce467057189..569355470cef642f3cc7b7c503680ca7eded776c 100644 --- a/tests/test_hook/test_runtime_info_hook.py +++ b/tests/test_hook/test_runtime_info_hook.py @@ -12,33 +12,21 @@ from mmengine.optim import OptimWrapper, OptimWrapperDict class TestRuntimeInfoHook(TestCase): - def test_before_run(self): - message_hub = MessageHub.get_instance( - 'runtime_info_hook_test_before_run') - runner = Mock() - runner.epoch = 3 - runner.iter = 30 - runner.max_epochs = 4 - runner.max_iters = 40 - runner.message_hub = message_hub - hook = RuntimeInfoHook() - hook.before_run(runner) - self.assertEqual(message_hub.get_info('epoch'), 3) - self.assertEqual(message_hub.get_info('iter'), 30) - self.assertEqual(message_hub.get_info('max_epochs'), 4) - self.assertEqual(message_hub.get_info('max_iters'), 40) - def test_before_train(self): message_hub = MessageHub.get_instance( 'runtime_info_hook_test_before_train') runner = Mock() runner.epoch = 7 runner.iter = 71 + runner.max_epochs = 4 + runner.max_iters = 40 runner.message_hub = message_hub hook = RuntimeInfoHook() hook.before_train(runner) self.assertEqual(message_hub.get_info('epoch'), 7) self.assertEqual(message_hub.get_info('iter'), 71) + self.assertEqual(message_hub.get_info('max_epochs'), 4) + self.assertEqual(message_hub.get_info('max_iters'), 40) def test_before_train_epoch(self): message_hub = MessageHub.get_instance( diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index 7cae59da7393f605ac0d1c3d00fc374a8747963f..a665046032d1f6a07f75ad3767c40a1ac4ad76a8 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -1241,6 +1241,8 @@ class TestRunner(TestCase): cfg.experiment_name = 'test_test2' runner = Runner.from_cfg(cfg) runner.test() + # Test run test without building train loop. + self.assertIsInstance(runner._train_loop, dict) # test run test without train and test components cfg = copy.deepcopy(self.epoch_based_cfg)