diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index eb5c3454bb38c1b686b57518f205715a55ce2001..62d9931704119ed837a7ec235a2dd099ac407e06 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -60,7 +60,7 @@ class EpochBasedTrainLoop(BaseLoop): self.run_iter(idx, data_batch) self.runner.call_hook('after_train_epoch') - self.runner._epoch += 1 + self.runner.epoch += 1 def run_iter(self, idx, data_batch: Sequence[Tuple[Any, BaseDataElement]]) -> None: @@ -85,7 +85,7 @@ class EpochBasedTrainLoop(BaseLoop): data_batch=data_batch, outputs=self.runner.outputs) - self.runner._iter += 1 + self.runner.iter += 1 @LOOPS.register_module() @@ -154,7 +154,7 @@ class IterBasedTrainLoop(BaseLoop): batch_idx=self.runner._iter, data_batch=data_batch, outputs=self.runner.outputs) - self.runner._iter += 1 + self.runner.iter += 1 @LOOPS.register_module() diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index a833570e235818ade3c7832aa78696b74bd655c4..a1b11511791a85d91285b0fe961f437a29e036af 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -311,7 +311,15 @@ class Runner: self._experiment_name = self.timestamp self.logger = self.build_logger(log_level=log_level) - # message hub used for component interaction + # Build `message_hub` for communication among components. + # `message_hub` can store log scalars (loss, learning rate) and + # runtime information (iter and epoch). Those components that do not + # have access to the runner can get iteration or epoch information + # from `message_hub`. For example, models can get the latest created + # `message_hub` by + # `self.message_hub=MessageHub.get_current_instance()` and then get + # current epoch by `cur_epoch = self.message_hub.get_info('epoch')`. + # See `MessageHub` and `ManagerMixin` for more details. self.message_hub = self.build_message_hub() # writer used for writing log or visualizing all kinds of data self.writer = self.build_writer(writer) @@ -407,11 +415,26 @@ class Runner: """int: Current epoch.""" return self._epoch + @epoch.setter + def epoch(self, epoch: int): + """Update epoch and synchronize epoch in :attr:`message_hub`.""" + self._epoch = epoch + # To allow components that cannot access runner to get current epoch. + self.message_hub.update_info('epoch', epoch) + @property def iter(self): - """int: Current epoch.""" + """int: Current iteration.""" return self._iter + @iter.setter + def iter(self, iter: int): + """Update iter and synchronize iter in :attr:`message_hub`.""" + self._iter = iter + # To allow components that cannot access runner to get current + # iteration. + self.message_hub.update_info('iter', iter) + @property def launcher(self): """str: Way to launcher multi processes.""" diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index 25fbdbcfa22c129247e3a7c15aa7926dd3bb5cae..7b497c8e54eb6c503f9e5efa12adc8b3dc47c18a 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -754,6 +754,9 @@ class TestRunner(TestCase): assert isinstance(runner.train_loop, EpochBasedTrainLoop) + assert runner.iter == runner.message_hub.get_info('iter') + assert runner.epoch == runner.message_hub.get_info('epoch') + for result, target, in zip(epoch_results, epoch_targets): self.assertEqual(result, target) for result, target, in zip(iter_results, iter_targets): @@ -786,6 +789,7 @@ class TestRunner(TestCase): runner.train() assert isinstance(runner.train_loop, IterBasedTrainLoop) + assert runner.iter == runner.message_hub.get_info('iter') self.assertEqual(len(epoch_results), 1) self.assertEqual(epoch_results[0], 0)