From 45567b1d1cb71c324701b29945693cbc082bc8f3 Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Thu, 21 Apr 2022 11:45:03 +0800
Subject: [PATCH] automaticaly update iter and epoch in message_hub (#168)

* automatic update iter and epoch in message_hub

* add docstring

* Update comment and docstring

* Fix as comment

* Fix docstring and comment

* refine comments
---
 mmengine/runner/loops.py         |  6 +++---
 mmengine/runner/runner.py        | 27 +++++++++++++++++++++++++--
 tests/test_runner/test_runner.py |  4 ++++
 3 files changed, 32 insertions(+), 5 deletions(-)

diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py
index eb5c3454..62d99317 100644
--- a/mmengine/runner/loops.py
+++ b/mmengine/runner/loops.py
@@ -60,7 +60,7 @@ class EpochBasedTrainLoop(BaseLoop):
             self.run_iter(idx, data_batch)
 
         self.runner.call_hook('after_train_epoch')
-        self.runner._epoch += 1
+        self.runner.epoch += 1
 
     def run_iter(self, idx,
                  data_batch: Sequence[Tuple[Any, BaseDataElement]]) -> None:
@@ -85,7 +85,7 @@ class EpochBasedTrainLoop(BaseLoop):
             data_batch=data_batch,
             outputs=self.runner.outputs)
 
-        self.runner._iter += 1
+        self.runner.iter += 1
 
 
 @LOOPS.register_module()
@@ -154,7 +154,7 @@ class IterBasedTrainLoop(BaseLoop):
             batch_idx=self.runner._iter,
             data_batch=data_batch,
             outputs=self.runner.outputs)
-        self.runner._iter += 1
+        self.runner.iter += 1
 
 
 @LOOPS.register_module()
diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index a833570e..a1b11511 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -311,7 +311,15 @@ class Runner:
             self._experiment_name = self.timestamp
 
         self.logger = self.build_logger(log_level=log_level)
-        # message hub used for component interaction
+        # Build `message_hub` for communication among components.
+        # `message_hub` can store log scalars (loss, learning rate) and
+        # runtime information (iter and epoch). Those components that do not
+        # have access to the runner can get iteration or epoch information
+        # from `message_hub`. For example, models can get the latest created
+        # `message_hub` by
+        # `self.message_hub=MessageHub.get_current_instance()` and then get
+        # current epoch by `cur_epoch = self.message_hub.get_info('epoch')`.
+        # See `MessageHub` and `ManagerMixin` for more details.
         self.message_hub = self.build_message_hub()
         # writer used for writing log or visualizing all kinds of data
         self.writer = self.build_writer(writer)
@@ -407,11 +415,26 @@ class Runner:
         """int: Current epoch."""
         return self._epoch
 
+    @epoch.setter
+    def epoch(self, epoch: int):
+        """Update epoch and synchronize epoch in :attr:`message_hub`."""
+        self._epoch = epoch
+        # To allow components that cannot access runner to get current epoch.
+        self.message_hub.update_info('epoch', epoch)
+
     @property
     def iter(self):
-        """int: Current epoch."""
+        """int: Current iteration."""
         return self._iter
 
+    @iter.setter
+    def iter(self, iter: int):
+        """Update iter and synchronize iter in :attr:`message_hub`."""
+        self._iter = iter
+        # To allow components that cannot access runner to get current
+        # iteration.
+        self.message_hub.update_info('iter', iter)
+
     @property
     def launcher(self):
         """str: Way to launcher multi processes."""
diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py
index 25fbdbcf..7b497c8e 100644
--- a/tests/test_runner/test_runner.py
+++ b/tests/test_runner/test_runner.py
@@ -754,6 +754,9 @@ class TestRunner(TestCase):
 
         assert isinstance(runner.train_loop, EpochBasedTrainLoop)
 
+        assert runner.iter == runner.message_hub.get_info('iter')
+        assert runner.epoch == runner.message_hub.get_info('epoch')
+
         for result, target, in zip(epoch_results, epoch_targets):
             self.assertEqual(result, target)
         for result, target, in zip(iter_results, iter_targets):
@@ -786,6 +789,7 @@ class TestRunner(TestCase):
         runner.train()
 
         assert isinstance(runner.train_loop, IterBasedTrainLoop)
+        assert runner.iter == runner.message_hub.get_info('iter')
 
         self.assertEqual(len(epoch_results), 1)
         self.assertEqual(epoch_results[0], 0)
-- 
GitLab