Skip to content
Snippets Groups Projects
Unverified Commit 45567b1d authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

automaticaly update iter and epoch in message_hub (#168)

* automatic update iter and epoch in message_hub

* add docstring

* Update comment and docstring

* Fix as comment

* Fix docstring and comment

* refine comments
parent 53101a1a
No related branches found
No related tags found
No related merge requests found
...@@ -60,7 +60,7 @@ class EpochBasedTrainLoop(BaseLoop): ...@@ -60,7 +60,7 @@ class EpochBasedTrainLoop(BaseLoop):
self.run_iter(idx, data_batch) self.run_iter(idx, data_batch)
self.runner.call_hook('after_train_epoch') self.runner.call_hook('after_train_epoch')
self.runner._epoch += 1 self.runner.epoch += 1
def run_iter(self, idx, def run_iter(self, idx,
data_batch: Sequence[Tuple[Any, BaseDataElement]]) -> None: data_batch: Sequence[Tuple[Any, BaseDataElement]]) -> None:
...@@ -85,7 +85,7 @@ class EpochBasedTrainLoop(BaseLoop): ...@@ -85,7 +85,7 @@ class EpochBasedTrainLoop(BaseLoop):
data_batch=data_batch, data_batch=data_batch,
outputs=self.runner.outputs) outputs=self.runner.outputs)
self.runner._iter += 1 self.runner.iter += 1
@LOOPS.register_module() @LOOPS.register_module()
...@@ -154,7 +154,7 @@ class IterBasedTrainLoop(BaseLoop): ...@@ -154,7 +154,7 @@ class IterBasedTrainLoop(BaseLoop):
batch_idx=self.runner._iter, batch_idx=self.runner._iter,
data_batch=data_batch, data_batch=data_batch,
outputs=self.runner.outputs) outputs=self.runner.outputs)
self.runner._iter += 1 self.runner.iter += 1
@LOOPS.register_module() @LOOPS.register_module()
......
...@@ -311,7 +311,15 @@ class Runner: ...@@ -311,7 +311,15 @@ class Runner:
self._experiment_name = self.timestamp self._experiment_name = self.timestamp
self.logger = self.build_logger(log_level=log_level) self.logger = self.build_logger(log_level=log_level)
# message hub used for component interaction # Build `message_hub` for communication among components.
# `message_hub` can store log scalars (loss, learning rate) and
# runtime information (iter and epoch). Those components that do not
# have access to the runner can get iteration or epoch information
# from `message_hub`. For example, models can get the latest created
# `message_hub` by
# `self.message_hub=MessageHub.get_current_instance()` and then get
# current epoch by `cur_epoch = self.message_hub.get_info('epoch')`.
# See `MessageHub` and `ManagerMixin` for more details.
self.message_hub = self.build_message_hub() self.message_hub = self.build_message_hub()
# writer used for writing log or visualizing all kinds of data # writer used for writing log or visualizing all kinds of data
self.writer = self.build_writer(writer) self.writer = self.build_writer(writer)
...@@ -407,11 +415,26 @@ class Runner: ...@@ -407,11 +415,26 @@ class Runner:
"""int: Current epoch.""" """int: Current epoch."""
return self._epoch return self._epoch
@epoch.setter
def epoch(self, epoch: int):
"""Update epoch and synchronize epoch in :attr:`message_hub`."""
self._epoch = epoch
# To allow components that cannot access runner to get current epoch.
self.message_hub.update_info('epoch', epoch)
@property @property
def iter(self): def iter(self):
"""int: Current epoch.""" """int: Current iteration."""
return self._iter return self._iter
@iter.setter
def iter(self, iter: int):
"""Update iter and synchronize iter in :attr:`message_hub`."""
self._iter = iter
# To allow components that cannot access runner to get current
# iteration.
self.message_hub.update_info('iter', iter)
@property @property
def launcher(self): def launcher(self):
"""str: Way to launcher multi processes.""" """str: Way to launcher multi processes."""
......
...@@ -754,6 +754,9 @@ class TestRunner(TestCase): ...@@ -754,6 +754,9 @@ class TestRunner(TestCase):
assert isinstance(runner.train_loop, EpochBasedTrainLoop) assert isinstance(runner.train_loop, EpochBasedTrainLoop)
assert runner.iter == runner.message_hub.get_info('iter')
assert runner.epoch == runner.message_hub.get_info('epoch')
for result, target, in zip(epoch_results, epoch_targets): for result, target, in zip(epoch_results, epoch_targets):
self.assertEqual(result, target) self.assertEqual(result, target)
for result, target, in zip(iter_results, iter_targets): for result, target, in zip(iter_results, iter_targets):
...@@ -786,6 +789,7 @@ class TestRunner(TestCase): ...@@ -786,6 +789,7 @@ class TestRunner(TestCase):
runner.train() runner.train()
assert isinstance(runner.train_loop, IterBasedTrainLoop) assert isinstance(runner.train_loop, IterBasedTrainLoop)
assert runner.iter == runner.message_hub.get_info('iter')
self.assertEqual(len(epoch_results), 1) self.assertEqual(len(epoch_results), 1)
self.assertEqual(epoch_results[0], 0) self.assertEqual(epoch_results[0], 0)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment