Skip to content
Snippets Groups Projects
Unverified Commit a706bbc0 authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

[Fix]: fix error and add unit test (#429)

parent f5cb45dc
No related branches found
No related tags found
No related merge requests found
......@@ -34,7 +34,7 @@ class RuntimeInfoHook(Hook):
runner.message_hub.update_info('iter', runner.iter)
runner.message_hub.update_info('max_epochs', runner.max_epochs)
runner.message_hub.update_info('max_iters', runner.max_iters)
if hasattr(runner.train_dataloader.dataset, 'dataset_meta'):
if hasattr(runner.train_dataloader.dataset, 'metainfo'):
runner.message_hub.update_info(
'dataset_meta', runner.train_dataloader.dataset.metainfo)
......
......@@ -15,18 +15,32 @@ class TestRuntimeInfoHook(TestCase):
def test_before_train(self):
message_hub = MessageHub.get_instance(
'runtime_info_hook_test_before_train')
class ToyDataset:
...
runner = Mock()
runner.epoch = 7
runner.iter = 71
runner.max_epochs = 4
runner.max_iters = 40
runner.message_hub = message_hub
runner.train_dataloader.dataset = ToyDataset()
hook = RuntimeInfoHook()
hook.before_train(runner)
self.assertEqual(message_hub.get_info('epoch'), 7)
self.assertEqual(message_hub.get_info('iter'), 71)
self.assertEqual(message_hub.get_info('max_epochs'), 4)
self.assertEqual(message_hub.get_info('max_iters'), 40)
with self.assertRaisesRegex(KeyError, 'dataset_meta is not found'):
message_hub.get_info('dataset_meta')
class ToyDatasetWithMeta:
metainfo = dict()
runner.train_dataloader.dataset = ToyDatasetWithMeta()
hook.before_train(runner)
self.assertEqual(message_hub.get_info('dataset_meta'), dict())
def test_before_train_epoch(self):
message_hub = MessageHub.get_instance(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment