From 45567b1d1cb71c324701b29945693cbc082bc8f3 Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Thu, 21 Apr 2022 11:45:03 +0800 Subject: [PATCH] automaticaly update iter and epoch in message_hub (#168) * automatic update iter and epoch in message_hub * add docstring * Update comment and docstring * Fix as comment * Fix docstring and comment * refine comments --- mmengine/runner/loops.py | 6 +++--- mmengine/runner/runner.py | 27 +++++++++++++++++++++++++-- tests/test_runner/test_runner.py | 4 ++++ 3 files changed, 32 insertions(+), 5 deletions(-) diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index eb5c3454..62d99317 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -60,7 +60,7 @@ class EpochBasedTrainLoop(BaseLoop): self.run_iter(idx, data_batch) self.runner.call_hook('after_train_epoch') - self.runner._epoch += 1 + self.runner.epoch += 1 def run_iter(self, idx, data_batch: Sequence[Tuple[Any, BaseDataElement]]) -> None: @@ -85,7 +85,7 @@ class EpochBasedTrainLoop(BaseLoop): data_batch=data_batch, outputs=self.runner.outputs) - self.runner._iter += 1 + self.runner.iter += 1 @LOOPS.register_module() @@ -154,7 +154,7 @@ class IterBasedTrainLoop(BaseLoop): batch_idx=self.runner._iter, data_batch=data_batch, outputs=self.runner.outputs) - self.runner._iter += 1 + self.runner.iter += 1 @LOOPS.register_module() diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index a833570e..a1b11511 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -311,7 +311,15 @@ class Runner: self._experiment_name = self.timestamp self.logger = self.build_logger(log_level=log_level) - # message hub used for component interaction + # Build `message_hub` for communication among components. + # `message_hub` can store log scalars (loss, learning rate) and + # runtime information (iter and epoch). Those components that do not + # have access to the runner can get iteration or epoch information + # from `message_hub`. For example, models can get the latest created + # `message_hub` by + # `self.message_hub=MessageHub.get_current_instance()` and then get + # current epoch by `cur_epoch = self.message_hub.get_info('epoch')`. + # See `MessageHub` and `ManagerMixin` for more details. self.message_hub = self.build_message_hub() # writer used for writing log or visualizing all kinds of data self.writer = self.build_writer(writer) @@ -407,11 +415,26 @@ class Runner: """int: Current epoch.""" return self._epoch + @epoch.setter + def epoch(self, epoch: int): + """Update epoch and synchronize epoch in :attr:`message_hub`.""" + self._epoch = epoch + # To allow components that cannot access runner to get current epoch. + self.message_hub.update_info('epoch', epoch) + @property def iter(self): - """int: Current epoch.""" + """int: Current iteration.""" return self._iter + @iter.setter + def iter(self, iter: int): + """Update iter and synchronize iter in :attr:`message_hub`.""" + self._iter = iter + # To allow components that cannot access runner to get current + # iteration. + self.message_hub.update_info('iter', iter) + @property def launcher(self): """str: Way to launcher multi processes.""" diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index 25fbdbcf..7b497c8e 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -754,6 +754,9 @@ class TestRunner(TestCase): assert isinstance(runner.train_loop, EpochBasedTrainLoop) + assert runner.iter == runner.message_hub.get_info('iter') + assert runner.epoch == runner.message_hub.get_info('epoch') + for result, target, in zip(epoch_results, epoch_targets): self.assertEqual(result, target) for result, target, in zip(iter_results, iter_targets): @@ -786,6 +789,7 @@ class TestRunner(TestCase): runner.train() assert isinstance(runner.train_loop, IterBasedTrainLoop) + assert runner.iter == runner.message_hub.get_info('iter') self.assertEqual(len(epoch_results), 1) self.assertEqual(epoch_results[0], 0) -- GitLab