Skip to content
Snippets Groups Projects
Unverified Commit 8b0c9c5f authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

[Fix] fix build train_loop during test (#295)

* fix build train_loop during test

* fix build train_loop during test

* fix build train_loop during test

* fix build train_loop during test

* Fix as comment
parent 819e10c2
No related branches found
No related tags found
No related merge requests found
...@@ -18,17 +18,12 @@ class RuntimeInfoHook(Hook): ...@@ -18,17 +18,12 @@ class RuntimeInfoHook(Hook):
priority = 'VERY_HIGH' priority = 'VERY_HIGH'
def before_run(self, runner) -> None:
"""Initialize runtime information."""
runner.message_hub.update_info('epoch', runner.epoch)
runner.message_hub.update_info('iter', runner.iter)
runner.message_hub.update_info('max_epochs', runner.max_epochs)
runner.message_hub.update_info('max_iters', runner.max_iters)
def before_train(self, runner) -> None: def before_train(self, runner) -> None:
"""Update resumed training state.""" """Update resumed training state."""
runner.message_hub.update_info('epoch', runner.epoch) runner.message_hub.update_info('epoch', runner.epoch)
runner.message_hub.update_info('iter', runner.iter) runner.message_hub.update_info('iter', runner.iter)
runner.message_hub.update_info('max_epochs', runner.max_epochs)
runner.message_hub.update_info('max_iters', runner.max_iters)
def before_train_epoch(self, runner) -> None: def before_train_epoch(self, runner) -> None:
"""Update current epoch information before every epoch.""" """Update current epoch information before every epoch."""
......
...@@ -12,33 +12,21 @@ from mmengine.optim import OptimWrapper, OptimWrapperDict ...@@ -12,33 +12,21 @@ from mmengine.optim import OptimWrapper, OptimWrapperDict
class TestRuntimeInfoHook(TestCase): class TestRuntimeInfoHook(TestCase):
def test_before_run(self):
message_hub = MessageHub.get_instance(
'runtime_info_hook_test_before_run')
runner = Mock()
runner.epoch = 3
runner.iter = 30
runner.max_epochs = 4
runner.max_iters = 40
runner.message_hub = message_hub
hook = RuntimeInfoHook()
hook.before_run(runner)
self.assertEqual(message_hub.get_info('epoch'), 3)
self.assertEqual(message_hub.get_info('iter'), 30)
self.assertEqual(message_hub.get_info('max_epochs'), 4)
self.assertEqual(message_hub.get_info('max_iters'), 40)
def test_before_train(self): def test_before_train(self):
message_hub = MessageHub.get_instance( message_hub = MessageHub.get_instance(
'runtime_info_hook_test_before_train') 'runtime_info_hook_test_before_train')
runner = Mock() runner = Mock()
runner.epoch = 7 runner.epoch = 7
runner.iter = 71 runner.iter = 71
runner.max_epochs = 4
runner.max_iters = 40
runner.message_hub = message_hub runner.message_hub = message_hub
hook = RuntimeInfoHook() hook = RuntimeInfoHook()
hook.before_train(runner) hook.before_train(runner)
self.assertEqual(message_hub.get_info('epoch'), 7) self.assertEqual(message_hub.get_info('epoch'), 7)
self.assertEqual(message_hub.get_info('iter'), 71) self.assertEqual(message_hub.get_info('iter'), 71)
self.assertEqual(message_hub.get_info('max_epochs'), 4)
self.assertEqual(message_hub.get_info('max_iters'), 40)
def test_before_train_epoch(self): def test_before_train_epoch(self):
message_hub = MessageHub.get_instance( message_hub = MessageHub.get_instance(
......
...@@ -1241,6 +1241,8 @@ class TestRunner(TestCase): ...@@ -1241,6 +1241,8 @@ class TestRunner(TestCase):
cfg.experiment_name = 'test_test2' cfg.experiment_name = 'test_test2'
runner = Runner.from_cfg(cfg) runner = Runner.from_cfg(cfg)
runner.test() runner.test()
# Test run test without building train loop.
self.assertIsInstance(runner._train_loop, dict)
# test run test without train and test components # test run test without train and test components
cfg = copy.deepcopy(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment