Skip to content
Snippets Groups Projects
Unverified Commit 45567b1d authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

automaticaly update iter and epoch in message_hub (#168)

* automatic update iter and epoch in message_hub

* add docstring

* Update comment and docstring

* Fix as comment

* Fix docstring and comment

* refine comments
parent 53101a1a
No related branches found
No related tags found
No related merge requests found
......@@ -60,7 +60,7 @@ class EpochBasedTrainLoop(BaseLoop):
self.run_iter(idx, data_batch)
self.runner.call_hook('after_train_epoch')
self.runner._epoch += 1
self.runner.epoch += 1
def run_iter(self, idx,
data_batch: Sequence[Tuple[Any, BaseDataElement]]) -> None:
......@@ -85,7 +85,7 @@ class EpochBasedTrainLoop(BaseLoop):
data_batch=data_batch,
outputs=self.runner.outputs)
self.runner._iter += 1
self.runner.iter += 1
@LOOPS.register_module()
......@@ -154,7 +154,7 @@ class IterBasedTrainLoop(BaseLoop):
batch_idx=self.runner._iter,
data_batch=data_batch,
outputs=self.runner.outputs)
self.runner._iter += 1
self.runner.iter += 1
@LOOPS.register_module()
......
......@@ -311,7 +311,15 @@ class Runner:
self._experiment_name = self.timestamp
self.logger = self.build_logger(log_level=log_level)
# message hub used for component interaction
# Build `message_hub` for communication among components.
# `message_hub` can store log scalars (loss, learning rate) and
# runtime information (iter and epoch). Those components that do not
# have access to the runner can get iteration or epoch information
# from `message_hub`. For example, models can get the latest created
# `message_hub` by
# `self.message_hub=MessageHub.get_current_instance()` and then get
# current epoch by `cur_epoch = self.message_hub.get_info('epoch')`.
# See `MessageHub` and `ManagerMixin` for more details.
self.message_hub = self.build_message_hub()
# writer used for writing log or visualizing all kinds of data
self.writer = self.build_writer(writer)
......@@ -407,11 +415,26 @@ class Runner:
"""int: Current epoch."""
return self._epoch
@epoch.setter
def epoch(self, epoch: int):
"""Update epoch and synchronize epoch in :attr:`message_hub`."""
self._epoch = epoch
# To allow components that cannot access runner to get current epoch.
self.message_hub.update_info('epoch', epoch)
@property
def iter(self):
"""int: Current epoch."""
"""int: Current iteration."""
return self._iter
@iter.setter
def iter(self, iter: int):
"""Update iter and synchronize iter in :attr:`message_hub`."""
self._iter = iter
# To allow components that cannot access runner to get current
# iteration.
self.message_hub.update_info('iter', iter)
@property
def launcher(self):
"""str: Way to launcher multi processes."""
......
......@@ -754,6 +754,9 @@ class TestRunner(TestCase):
assert isinstance(runner.train_loop, EpochBasedTrainLoop)
assert runner.iter == runner.message_hub.get_info('iter')
assert runner.epoch == runner.message_hub.get_info('epoch')
for result, target, in zip(epoch_results, epoch_targets):
self.assertEqual(result, target)
for result, target, in zip(iter_results, iter_targets):
......@@ -786,6 +789,7 @@ class TestRunner(TestCase):
runner.train()
assert isinstance(runner.train_loop, IterBasedTrainLoop)
assert runner.iter == runner.message_hub.get_info('iter')
self.assertEqual(len(epoch_results), 1)
self.assertEqual(epoch_results[0], 0)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment