From 8b0c9c5f6fd3281f5739eab4c9dad41beea55d28 Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Mon, 13 Jun 2022 21:23:46 +0800 Subject: [PATCH] [Fix] fix build train_loop during test (#295) * fix build train_loop during test * fix build train_loop during test * fix build train_loop during test * fix build train_loop during test * Fix as comment --- mmengine/hooks/runtime_info_hook.py | 9 ++------- tests/test_hook/test_runtime_info_hook.py | 20 ++++---------------- tests/test_runner/test_runner.py | 2 ++ 3 files changed, 8 insertions(+), 23 deletions(-) diff --git a/mmengine/hooks/runtime_info_hook.py b/mmengine/hooks/runtime_info_hook.py index 091ced50..6098d540 100644 --- a/mmengine/hooks/runtime_info_hook.py +++ b/mmengine/hooks/runtime_info_hook.py @@ -18,17 +18,12 @@ class RuntimeInfoHook(Hook): priority = 'VERY_HIGH' - def before_run(self, runner) -> None: - """Initialize runtime information.""" - runner.message_hub.update_info('epoch', runner.epoch) - runner.message_hub.update_info('iter', runner.iter) - runner.message_hub.update_info('max_epochs', runner.max_epochs) - runner.message_hub.update_info('max_iters', runner.max_iters) - def before_train(self, runner) -> None: """Update resumed training state.""" runner.message_hub.update_info('epoch', runner.epoch) runner.message_hub.update_info('iter', runner.iter) + runner.message_hub.update_info('max_epochs', runner.max_epochs) + runner.message_hub.update_info('max_iters', runner.max_iters) def before_train_epoch(self, runner) -> None: """Update current epoch information before every epoch.""" diff --git a/tests/test_hook/test_runtime_info_hook.py b/tests/test_hook/test_runtime_info_hook.py index b57e26eb..56935547 100644 --- a/tests/test_hook/test_runtime_info_hook.py +++ b/tests/test_hook/test_runtime_info_hook.py @@ -12,33 +12,21 @@ from mmengine.optim import OptimWrapper, OptimWrapperDict class TestRuntimeInfoHook(TestCase): - def test_before_run(self): - message_hub = MessageHub.get_instance( - 'runtime_info_hook_test_before_run') - runner = Mock() - runner.epoch = 3 - runner.iter = 30 - runner.max_epochs = 4 - runner.max_iters = 40 - runner.message_hub = message_hub - hook = RuntimeInfoHook() - hook.before_run(runner) - self.assertEqual(message_hub.get_info('epoch'), 3) - self.assertEqual(message_hub.get_info('iter'), 30) - self.assertEqual(message_hub.get_info('max_epochs'), 4) - self.assertEqual(message_hub.get_info('max_iters'), 40) - def test_before_train(self): message_hub = MessageHub.get_instance( 'runtime_info_hook_test_before_train') runner = Mock() runner.epoch = 7 runner.iter = 71 + runner.max_epochs = 4 + runner.max_iters = 40 runner.message_hub = message_hub hook = RuntimeInfoHook() hook.before_train(runner) self.assertEqual(message_hub.get_info('epoch'), 7) self.assertEqual(message_hub.get_info('iter'), 71) + self.assertEqual(message_hub.get_info('max_epochs'), 4) + self.assertEqual(message_hub.get_info('max_iters'), 40) def test_before_train_epoch(self): message_hub = MessageHub.get_instance( diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index 7cae59da..a6650460 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -1241,6 +1241,8 @@ class TestRunner(TestCase): cfg.experiment_name = 'test_test2' runner = Runner.from_cfg(cfg) runner.test() + # Test run test without building train loop. + self.assertIsInstance(runner._train_loop, dict) # test run test without train and test components cfg = copy.deepcopy(self.epoch_based_cfg) -- GitLab