From a7961407e4da6ee7aaf630792b851f42c6ba1706 Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Sun, 13 Mar 2022 16:48:09 +0800
Subject: [PATCH] [Refactor] Refactor the interfaces of Hook and its subclassed
 (#117)

* Fix hook

* Fix

* Fix docs

* FIx

* Fix

* Fix as comment

* update

* Fix hook

* Fix hook

* Fix hook

* Fix itertimerhook

* Fix iter_timer_hook

* Fix

* Fix

* fix logger hook

* Fix loggerhook

* update cur_dataloader

* Fix docstring

* Fix docstring

* Fix as commet

* Fix as commet

* Fix as comment

* rename is_last_epoch, enhance and add after_val before_val .etc

* fix typo in docstring

* remove resolved TODO

* refactor docstring
---
 mmengine/hooks/checkpoint_hook.py         |  10 +-
 mmengine/hooks/empty_cache_hook.py        |  31 +--
 mmengine/hooks/hook.py                    | 242 ++++++++++++----------
 mmengine/hooks/iter_timer_hook.py         |  27 ++-
 mmengine/hooks/logger_hook.py             |   9 +-
 mmengine/hooks/sampler_seed_hook.py       |  10 +-
 mmengine/hooks/sync_buffer_hook.py        |   2 +-
 tests/test_hook/test_empty_cache_hook.py  |   6 +-
 tests/test_hook/test_hook.py              |  36 ++--
 tests/test_hook/test_iter_timer_hook.py   |  14 +-
 tests/test_hook/test_logger_hook.py       |   8 +-
 tests/test_hook/test_sampler_seed_hook.py |  24 +--
 tests/test_hook/test_sync_buffers_hook.py |   2 +-
 13 files changed, 235 insertions(+), 186 deletions(-)

diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py
index d8c75d1f..313900b7 100644
--- a/mmengine/hooks/checkpoint_hook.py
+++ b/mmengine/hooks/checkpoint_hook.py
@@ -119,9 +119,8 @@ class CheckpointHook(Hook):
         # save checkpoint for following cases:
         # 1. every ``self.interval`` epochs
         # 2. reach the last epoch of training
-        if self.every_n_epochs(
-                runner, self.interval) or (self.save_last
-                                           and self.is_last_epoch(runner)):
+        if self.every_n_epochs(runner, self.interval) or (
+                self.save_last and self.is_last_train_epoch(runner)):
             runner.logger.info(f'Saving checkpoint at \
                     {runner.epoch + 1} epochs')
             if self.sync_buffer:
@@ -187,9 +186,8 @@ class CheckpointHook(Hook):
         # save checkpoint for following cases:
         # 1. every ``self.interval`` iterations
         # 2. reach the last iteration of training
-        if self.every_n_iters(
-                runner, self.interval) or (self.save_last
-                                           and self.is_last_iter(runner)):
+        if self.every_n_iters(runner, self.interval) or \
+                (self.save_last and self.is_last_iter(runner, mode='train')):
             runner.logger.info(f'Saving checkpoint at \
                     {runner.iter + 1} iterations')
             if self.sync_buffer:
diff --git a/mmengine/hooks/empty_cache_hook.py b/mmengine/hooks/empty_cache_hook.py
index a37de337..cb20a744 100644
--- a/mmengine/hooks/empty_cache_hook.py
+++ b/mmengine/hooks/empty_cache_hook.py
@@ -30,16 +30,16 @@ class EmptyCacheHook(Hook):
                  before_epoch: bool = False,
                  after_epoch: bool = True,
                  after_iter: bool = False) -> None:
-        self._before_epoch = before_epoch
-        self._after_epoch = after_epoch
-        self._after_iter = after_iter
+        self._do_before_epoch = before_epoch
+        self._do_after_epoch = after_epoch
+        self._do_after_iter = after_iter
 
-    def after_iter(self,
-                   runner,
-                   data_batch: DATA_BATCH = None,
-                   outputs:
-                   Optional[Union[dict, Sequence[BaseDataSample]]] = None)\
-            -> None:
+    def _after_iter(self,
+                    runner,
+                    data_batch: DATA_BATCH = None,
+                    outputs: Optional[Union[dict,
+                                            Sequence[BaseDataSample]]] = None,
+                    mode: str = 'train') -> None:
         """Empty cache after an iteration.
 
         Args:
@@ -48,24 +48,27 @@ class EmptyCacheHook(Hook):
                 from dataloader. Defaults to None.
             outputs (dict or sequence, optional): Outputs from model.
                 Defaults to None.
+            mode (str): Current mode of runner. Defaults to 'train'.
         """
-        if self._after_iter:
+        if self._do_after_iter:
             torch.cuda.empty_cache()
 
-    def before_epoch(self, runner) -> None:
+    def _before_epoch(self, runner, mode: str = 'train') -> None:
         """Empty cache before an epoch.
 
         Args:
             runner (Runner): The runner of the training process.
+            mode (str): Current mode of runner. Defaults to 'train'.
         """
-        if self._before_epoch:
+        if self._do_before_epoch:
             torch.cuda.empty_cache()
 
-    def after_epoch(self, runner) -> None:
+    def _after_epoch(self, runner, mode: str = 'train') -> None:
         """Empty cache after an epoch.
 
         Args:
             runner (Runner): The runner of the training process.
+            mode (str): Current mode of runner. Defaults to 'train'.
         """
-        if self._after_epoch:
+        if self._do_after_epoch:
             torch.cuda.empty_cache()
diff --git a/mmengine/hooks/hook.py b/mmengine/hooks/hook.py
index 91729375..7eeb5e54 100644
--- a/mmengine/hooks/hook.py
+++ b/mmengine/hooks/hook.py
@@ -16,20 +16,20 @@ class Hook:
 
     def before_run(self, runner) -> None:
         """All subclasses should override this method, if they need any
-        operations before the training process.
+        operations before the training validation or testing process.
 
         Args:
-            runner (Runner): The runner of the training/validation/testing
+            runner (Runner): The runner of the training, validation or testing
                 process.
         """
         pass
 
     def after_run(self, runner) -> None:
         """All subclasses should override this method, if they need any
-        operations after the training process.
+        operations before the training validation or testing process.
 
         Args:
-            runner (Runner): The runner of the training/validation/testing
+            runner (Runner): The runner of the training, validation or testing
                 process.
         """
         pass
@@ -54,7 +54,7 @@ class Hook:
 
     def before_val(self, runner) -> None:
         """All subclasses should override this method, if they need any
-        operations before val.
+        operations before validation.
 
         Args:
             runner (Runner): The runner of the validation process.
@@ -63,7 +63,7 @@ class Hook:
 
     def after_val(self, runner) -> None:
         """All subclasses should override this method, if they need any
-        operations after val.
+        operations after validation.
 
         Args:
             runner (Runner): The runner of the validation process.
@@ -72,7 +72,7 @@ class Hook:
 
     def before_test(self, runner) -> None:
         """All subclasses should override this method, if they need any
-        operations before test.
+        operations before testing.
 
         Args:
             runner (Runner): The runner of the testing process.
@@ -81,67 +81,21 @@ class Hook:
 
     def after_test(self, runner) -> None:
         """All subclasses should override this method, if they need any
-        operations after test.
+        operations after testing.
 
         Args:
             runner (Runner): The runner of the testing process.
         """
         pass
 
-    def before_epoch(self, runner) -> None:
-        """All subclasses should override this method, if they need any
-        operations before each epoch.
-
-        Args:
-            runner (Runner): The runner of the training process.
-        """
-        pass
-
-    def after_epoch(self, runner) -> None:
-        """All subclasses should override this method, if they need any
-        operations after each epoch.
-
-        Args:
-            runner (Runner): The runner of the training process.
-        """
-        pass
-
-    def before_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
-        """All subclasses should override this method, if they need any
-        operations before each iter.
-
-        Args:
-            runner (Runner): The runner of the training process.
-            data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
-                Data from dataloader. Defaults to None.
-        """
-        pass
-
-    def after_iter(self,
-                   runner,
-                   data_batch: DATA_BATCH = None,
-                   outputs:
-                   Optional[Union[dict, Sequence[BaseDataSample]]] = None) \
-            -> None:
-        """All subclasses should override this method, if they need any
-        operations after each epoch.
-
-        Args:
-            runner (Runner): The runner of the training process.
-            data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
-                Data from dataloader. Defaults to None.
-            outputs (dict or sequence, optional): Outputs from model. Defaults
-                to None.
-        """
-        pass
-
     def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
         """All subclasses should override this method, if they need any
         operations before saving the checkpoint.
 
         Args:
-            runner (Runner): The runner of the training process.
-            checkpoints (dict): Model's checkpoint.
+            runner (Runner): The runner of the training, validation or testing
+                process.
+            checkpoint (dict): Model's checkpoint.
         """
         pass
 
@@ -150,8 +104,9 @@ class Hook:
         operations after loading the checkpoint.
 
         Args:
-            runner (Runner): The runner of the training process.
-            checkpoints (dict): Model's checkpoint.
+            runner (Runner): The runner of the training, validation or testing
+                process.
+            checkpoint (dict): Model's checkpoint.
         """
         pass
 
@@ -162,25 +117,25 @@ class Hook:
         Args:
             runner (Runner): The runner of the training process.
         """
-        self.before_epoch(runner)
+        self._before_epoch(runner, mode='train')
 
     def before_val_epoch(self, runner) -> None:
         """All subclasses should override this method, if they need any
         operations before each validation epoch.
 
         Args:
-            runner (Runner): The runner of the training process.
+            runner (Runner): The runner of the validation process.
         """
-        self.before_epoch(runner)
+        self._before_epoch(runner, mode='val')
 
     def before_test_epoch(self, runner) -> None:
         """All subclasses should override this method, if they need any
         operations before each test epoch.
 
         Args:
-            runner (Runner): The runner of the training process.
+            runner (Runner): The runner of the testing process.
         """
-        self.before_epoch(runner)
+        self._before_epoch(runner, mode='test')
 
     def after_train_epoch(self, runner) -> None:
         """All subclasses should override this method, if they need any
@@ -189,25 +144,25 @@ class Hook:
         Args:
             runner (Runner): The runner of the training process.
         """
-        self.after_epoch(runner)
+        self._after_epoch(runner, mode='train')
 
     def after_val_epoch(self, runner) -> None:
         """All subclasses should override this method, if they need any
         operations after each validation epoch.
 
         Args:
-            runner (Runner): The runner of the training process.
+            runner (Runner): The runner of the validation process.
         """
-        self.after_epoch(runner)
+        self._after_epoch(runner, mode='val')
 
     def after_test_epoch(self, runner) -> None:
         """All subclasses should override this method, if they need any
         operations after each test epoch.
 
         Args:
-            runner (Runner): The runner of the training process.
+            runner (Runner): The runner of the testing process.
         """
-        self.after_epoch(runner)
+        self._after_epoch(runner, mode='test')
 
     def before_train_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
         """All subclasses should override this method, if they need any
@@ -218,29 +173,29 @@ class Hook:
             data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
                 Data from dataloader. Defaults to None.
         """
-        self.before_iter(runner, data_batch=None)
+        self._before_iter(runner, data_batch=data_batch, mode='train')
 
     def before_val_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
         """All subclasses should override this method, if they need any
         operations before each validation iteration.
 
         Args:
-            runner (Runner): The runner of the training process.
+            runner (Runner): The runner of the validation process.
             data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
                 Data from dataloader. Defaults to None.
         """
-        self.before_iter(runner, data_batch=None)
+        self._before_iter(runner, data_batch=data_batch, mode='val')
 
     def before_test_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
         """All subclasses should override this method, if they need any
         operations before each test iteration.
 
         Args:
-            runner (Runner): The runner of the training process.
+            runner (Runner): The runner of the testing process.
             data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
                 Data from dataloader. Defaults to None.
         """
-        self.before_iter(runner, data_batch=None)
+        self._before_iter(runner, data_batch=data_batch, mode='test')
 
     def after_train_iter(self,
                          runner,
@@ -256,7 +211,8 @@ class Hook:
             outputs (dict, optional): Outputs from model.
                 Defaults to None.
         """
-        self.after_iter(runner, data_batch=None, outputs=None)
+        self._after_iter(
+            runner, data_batch=data_batch, outputs=outputs, mode='train')
 
     def after_val_iter(self,
                        runner,
@@ -267,13 +223,14 @@ class Hook:
         operations after each validation iteration.
 
         Args:
-            runner (Runner): The runner of the training process.
+            runner (Runner): The runner of the validation process.
             data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
                 Data from dataloader. Defaults to None.
             outputs (dict or sequence, optional): Outputs from
                 model. Defaults to None.
         """
-        self.after_iter(runner, data_batch=None, outputs=None)
+        self._after_iter(
+            runner, data_batch=data_batch, outputs=outputs, mode='val')
 
     def after_test_iter(
             self,
@@ -284,48 +241,108 @@ class Hook:
         operations after each test iteration.
 
         Args:
-            runner (Runner): The runner of the training process.
+            runner (Runner): The runner of the training  process.
             data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
                 Data from dataloader. Defaults to None.
             outputs (dict, optional): Outputs from model.
                 Defaults to None.
         """
-        self.after_iter(runner, data_batch=None, outputs=None)
+        self._after_iter(
+            runner, data_batch=data_batch, outputs=outputs, mode='test')
+
+    def _before_epoch(self, runner, mode: str = 'train') -> None:
+        """All subclasses should override this method, if they need any
+        operations before each epoch.
+
+        Args:
+            runner (Runner): The runner of the training, validation or testing
+                process.
+            mode (str): Current mode of runner. Defaults to 'train'.
+        """
+        pass
+
+    def _after_epoch(self, runner, mode: str = 'train') -> None:
+        """All subclasses should override this method, if they need any
+        operations after each epoch.
+
+        Args:
+            runner (Runner): The runner of the training, validation or testing
+                process.
+            mode (str): Current mode of runner. Defaults to 'train'.
+        """
+        pass
+
+    def _before_iter(self,
+                     runner,
+                     data_batch: DATA_BATCH = None,
+                     mode: str = 'train') -> None:
+        """All subclasses should override this method, if they need any
+        operations before each iter.
+
+        Args:
+            runner (Runner): The runner of the training, validation or testing
+                process.
+            data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
+                Data from dataloader. Defaults to None.
+            mode (str): Current mode of runner. Defaults to 'train'.
+        """
+        pass
+
+    def _after_iter(self,
+                    runner,
+                    data_batch: DATA_BATCH = None,
+                    outputs: Optional[Union[Sequence[BaseDataSample],
+                                            dict]] = None,
+                    mode: str = 'train') -> None:
+        """All subclasses should override this method, if they need any
+        operations after each epoch.
+
+        Args:
+            runner (Runner): The runner of the training, validation or testing
+                process.
+            data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
+                Data from dataloader. Defaults to None.
+            outputs (Sequence[BaseDataSample], optional): Outputs from model.
+                Defaults to None.
+            mode (str): Current mode of runner. Defaults to 'train'.
+        """
+        pass
 
     def every_n_epochs(self, runner, n: int) -> bool:
-        """Test whether or not current epoch can be evenly divided by n.
+        """Test whether current epoch can be evenly divided by n.
 
         Args:
-            runner (Runner): The runner of the training process.
-            n (int): Whether or not current epoch can be evenly divided by n.
+            runner (Runner): The runner of the training, validation or testing
+                process.
+            n (int): Whether current epoch can be evenly divided by n.
 
         Returns:
-            bool: whether or not current epoch can be evenly divided by n.
+            bool: Whether current epoch can be evenly divided by n.
         """
         return (runner.epoch + 1) % n == 0 if n > 0 else False
 
     def every_n_inner_iters(self, runner, n: int) -> bool:
-        """Test whether or not current inner iteration can be evenly divided by
-        n.
+        """Test whether current inner iteration can be evenly divided by n.
 
         Args:
-            runner (Runner): The runner of the training process.
-            n (int): Whether or not current inner iteration can be evenly
+            runner (Runner): The runner of the training, validation or testing
+                process.
+            n (int): Whether current inner iteration can be evenly
                 divided by n.
 
         Returns:
-            bool: whether or not current inner iteration can be evenly
+            bool: Whether current inner iteration can be evenly
             divided by n.
         """
         return (runner.inner_iter + 1) % n == 0 if n > 0 else False
 
     def every_n_iters(self, runner, n: int) -> bool:
-        """Test whether or not current iteration can be evenly divided by n.
+        """Test whether current iteration can be evenly divided by n.
 
         Args:
-            runner (Runner): The runner of the training process.
-            n (int): Whether or not current iteration can be
-                evenly divided by n.
+            runner (Runner): The runner of the training, validation or testing
+                process.
+            n (int): Whether current iteration can be evenly divided by n.
 
         Returns:
             bool: Return True if the current iteration can be evenly divided
@@ -334,35 +351,46 @@ class Hook:
         return (runner.iter + 1) % n == 0 if n > 0 else False
 
     def end_of_epoch(self, runner) -> bool:
-        """Check whether the current epoch reaches the `max_epochs` or not.
+        """Check whether the current iteration reaches the last iteration of
+        current dataloader.
 
         Args:
-            runner (Runner): The runner of the training process.
+            runner (Runner): The runner of the training, validation or testing
+                process.
 
         Returns:
-            bool: whether the end of current epoch or not.
+            bool: Whether reaches the end of current epoch or not.
         """
-        return runner.inner_iter + 1 == len(runner.data_loader)
+        return runner.inner_iter + 1 == len(runner.cur_dataloader)
 
-    def is_last_epoch(self, runner) -> bool:
-        """Test whether or not current epoch is the last epoch.
+    def is_last_train_epoch(self, runner) -> bool:
+        """Test whether current epoch is the last train epoch.
 
         Args:
             runner (Runner): The runner of the training process.
 
         Returns:
-            bool: bool: Return True if the current epoch reaches the
-            `max_epochs`, otherwise False.
+            bool: Whether reaches the end of training epoch.
         """
-        return runner.epoch + 1 == runner._max_epochs
+        return runner.epoch + 1 == runner.train_loop.max_epochs
 
-    def is_last_iter(self, runner) -> bool:
-        """Test whether or not current epoch is the last iteration.
+    def is_last_iter(self, runner, mode='train') -> bool:
+        """Test whether current iteration is the last iteration.
 
         Args:
-            runner (Runner): The runner of the training process.
+            runner (Runner): The runner of the training, validation or testing
+                process.
 
         Returns:
-            bool: whether or not current iteration is the last iteration.
-        """
-        return runner.iter + 1 == runner._max_iters
+            bool: Whether current iteration is the last iteration.
+            mode (str): Current mode of runner. Defaults to 'train'.
+        """
+        if mode == 'train':
+            return runner.iter + 1 == runner.train_loop.max_iters
+        elif mode == 'val':
+            return runner.iter + 1 == runner.val_loop.max_iters
+        elif mode == 'test':
+            return runner.iter + 1 == runner.test_loop.max_iters
+        else:
+            raise ValueError('mode should be train, val or test but got'
+                             f'{mode}')
diff --git a/mmengine/hooks/iter_timer_hook.py b/mmengine/hooks/iter_timer_hook.py
index 72701504..824774ed 100644
--- a/mmengine/hooks/iter_timer_hook.py
+++ b/mmengine/hooks/iter_timer_hook.py
@@ -18,30 +18,37 @@ class IterTimerHook(Hook):
 
     priority = 'NORMAL'
 
-    def before_epoch(self, runner) -> None:
+    def _before_epoch(self, runner, mode: str = 'train') -> None:
         """Record time flag before start a epoch.
 
         Args:
             runner (Runner): The runner of the training process.
+            mode (str): Current mode of runner. Defaults to 'train'.
         """
         self.t = time.time()
 
-    def before_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
+    def _before_iter(self,
+                     runner,
+                     data_batch: DATA_BATCH = None,
+                     mode: str = 'train') -> None:
         """Logging time for loading data and update the time flag.
 
         Args:
             runner (Runner): The runner of the training process.
             data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
                 from dataloader. Defaults to None.
+            mode (str): Current mode of runner. Defaults to 'train'.
         """
         # TODO: update for new logging system
-        runner.log_buffer.update({'data_time': time.time() - self.t})
+        runner.message_hub.update_log(f'{mode}/data_time',
+                                      time.time() - self.t)
 
-    def after_iter(self,
-                   runner,
-                   data_batch: DATA_BATCH = None,
-                   outputs:
-                   Optional[Union[dict, Sequence[BaseDataSample]]] = None) \
+    def _after_iter(self,
+                    runner,
+                    data_batch: DATA_BATCH = None,
+                    outputs:
+                    Optional[Union[dict, Sequence[BaseDataSample]]] = None,
+                    mode: str = 'train') \
             -> None:
         """Logging time for a iteration and update the time flag.
 
@@ -51,7 +58,9 @@ class IterTimerHook(Hook):
                 from dataloader. Defaults to None.
             outputs (dict or sequence, optional): Outputs from model. Defaults
                 to None.
+            mode (str): Current mode of runner. Defaults to 'train'.
         """
         # TODO: update for new logging system
-        runner.log_buffer.update({'time': time.time() - self.t})
+
+        runner.message_hub.update_log(f'{mode}/time', time.time() - self.t)
         self.t = time.time()
diff --git a/mmengine/hooks/logger_hook.py b/mmengine/hooks/logger_hook.py
index b557c2d0..9e2518bf 100644
--- a/mmengine/hooks/logger_hook.py
+++ b/mmengine/hooks/logger_hook.py
@@ -264,14 +264,15 @@ class LoggerHook(Hook):
         # by iter:  Iter [100/100000]
         if self.by_epoch:
             log_str = f'Epoch [{cur_epoch}]' \
-                      f'[{cur_iter}/{len(runner.data_loader)}]\t'
+                      f'[{cur_iter}/{len(runner.cur_dataloader)}]\t'
         else:
-            log_str = f'Iter [{cur_iter}/{runner.max_iters}]\t'
+            log_str = f'Iter [{cur_iter}/{runner.train_loop.max_iters}]\t'
         log_str += f'{lr_momentum_str}, '
         # Calculate eta time.
         self.time_sec_tot += (tag['time'] * self.interval)
         time_sec_avg = self.time_sec_tot / (runner.iter - self.start_iter + 1)
-        eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
+        eta_sec = time_sec_avg * (
+            runner.train_loop.max_iters - runner.iter - 1)
         eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
         log_str += f'eta: {eta_str}, '
         log_str += f'time: {tag["time"]:.3f}, ' \
@@ -302,7 +303,7 @@ class LoggerHook(Hook):
         """
         tag = self._collect_info(runner, 'val')
         # Compatible with function `log` https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/logger/text.py # noqa E501
-        eval_iter = len(runner.data_loader)
+        eval_iter = len(runner.cur_dataloader)
         cur_iter = self._get_iter(runner)
         cur_epoch = self._get_epoch(runner, 'val')
         # val/test time
diff --git a/mmengine/hooks/sampler_seed_hook.py b/mmengine/hooks/sampler_seed_hook.py
index d7d243d8..eed3fa90 100644
--- a/mmengine/hooks/sampler_seed_hook.py
+++ b/mmengine/hooks/sampler_seed_hook.py
@@ -14,15 +14,15 @@ class DistSamplerSeedHook(Hook):
 
     priority = 'NORMAL'
 
-    def before_epoch(self, runner) -> None:
+    def before_train_epoch(self, runner, mode: str = 'train') -> None:
         """Set the seed for sampler and batch_sampler.
 
         Args:
             runner (Runner): The runner of the training process.
         """
-        if hasattr(runner.data_loader.sampler, 'set_epoch'):
+        if hasattr(runner.cur_dataloader.sampler, 'set_epoch'):
             # in case the data loader uses `SequentialSampler` in Pytorch
-            runner.data_loader.sampler.set_epoch(runner.epoch)
-        elif hasattr(runner.data_loader.batch_sampler.sampler, 'set_epoch'):
+            runner.cur_dataloader.sampler.set_epoch(runner.epoch)
+        elif hasattr(runner.cur_dataloader.batch_sampler.sampler, 'set_epoch'):
             # batch sampler in pytorch warps the sampler as its attributes.
-            runner.data_loader.batch_sampler.sampler.set_epoch(runner.epoch)
+            runner.cur_dataloader.batch_sampler.sampler.set_epoch(runner.epoch)
diff --git a/mmengine/hooks/sync_buffer_hook.py b/mmengine/hooks/sync_buffer_hook.py
index 9aa6402d..37b62f98 100644
--- a/mmengine/hooks/sync_buffer_hook.py
+++ b/mmengine/hooks/sync_buffer_hook.py
@@ -89,7 +89,7 @@ class SyncBuffersHook(Hook):
     def __init__(self) -> None:
         self.distributed = dist.IS_DIST
 
-    def after_epoch(self, runner) -> None:
+    def after_train_epoch(self, runner) -> None:
         """All-reduce model buffers at the end of each epoch.
 
         Args:
diff --git a/tests/test_hook/test_empty_cache_hook.py b/tests/test_hook/test_empty_cache_hook.py
index a09806d0..dc8ce8fc 100644
--- a/tests/test_hook/test_empty_cache_hook.py
+++ b/tests/test_hook/test_empty_cache_hook.py
@@ -9,6 +9,6 @@ class TestEmptyCacheHook:
     def test_emtpy_cache_hook(self):
         Hook = EmptyCacheHook(True, True, True)
         Runner = Mock()
-        Hook.after_iter(Runner)
-        Hook.before_epoch(Runner)
-        Hook.after_epoch(Runner)
+        Hook._after_iter(Runner)
+        Hook._before_epoch(Runner)
+        Hook._after_epoch(Runner)
diff --git a/tests/test_hook/test_hook.py b/tests/test_hook/test_hook.py
index 5884a161..d55f8a66 100644
--- a/tests/test_hook/test_hook.py
+++ b/tests/test_hook/test_hook.py
@@ -1,6 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from unittest.mock import Mock
 
+import pytest
+
 from mmengine.hooks import Hook
 
 
@@ -19,25 +21,25 @@ class TestHook:
     def test_before_epoch(self):
         hook = Hook()
         runner = Mock()
-        hook.before_epoch(runner)
+        hook._before_epoch(runner)
 
     def test_after_epoch(self):
         hook = Hook()
         runner = Mock()
-        hook.after_epoch(runner)
+        hook._after_epoch(runner)
 
     def test_before_iter(self):
         hook = Hook()
         runner = Mock()
         data_batch = {}
-        hook.before_iter(runner, data_batch)
+        hook._before_iter(runner, data_batch)
 
     def test_after_iter(self):
         hook = Hook()
         runner = Mock()
         data_batch = {}
         outputs = {}
-        hook.after_iter(runner, data_batch, outputs)
+        hook._after_iter(runner, data_batch, outputs)
 
     def test_before_save_checkpoint(self):
         hook = Hook()
@@ -161,7 +163,8 @@ class TestHook:
 
         # last inner iter
         runner.inner_iter = 1
-        runner.data_loader.__len__ = Mock(return_value=2)
+        runner.cur_dataloader.__len__ = Mock(return_value=2)
+        runner.cur_dataloader.__len__ = Mock(return_value=2)
         return_val = hook.end_of_epoch(runner)
         assert return_val
 
@@ -170,19 +173,19 @@ class TestHook:
         return_val = hook.end_of_epoch(runner)
         assert not return_val
 
-    def test_is_last_epoch(self):
+    def test_is_last_train_epoch(self):
         hook = Hook()
         runner = Mock()
 
         # last epoch
         runner.epoch = 1
-        runner._max_epochs = 2
-        return_val = hook.is_last_epoch(runner)
+        runner.train_loop.max_epochs = 2
+        return_val = hook.is_last_train_epoch(runner)
         assert return_val
 
         # not the last epoch
-        runner.epoch = 0
-        return_val = hook.is_last_epoch(runner)
+        runner.train_loop.max_epochs = 0
+        return_val = hook.is_last_train_epoch(runner)
         assert not return_val
 
     def test_is_last_iter(self):
@@ -191,11 +194,18 @@ class TestHook:
 
         # last iter
         runner.iter = 1
-        runner._max_iters = 2
+        runner.train_loop.max_iters = 2
         return_val = hook.is_last_iter(runner)
         assert return_val
 
         # not the last iter
-        runner.iter = 0
-        return_val = hook.is_last_iter(runner)
+        runner.val_loop.max_iters = 0
+        return_val = hook.is_last_iter(runner, mode='val')
         assert not return_val
+
+        runner.test_loop.max_iters = 0
+        return_val = hook.is_last_iter(runner, mode='test')
+        assert not return_val
+
+        with pytest.raises(ValueError):
+            hook.is_last_iter(runner, mode='error_mode')
diff --git a/tests/test_hook/test_iter_timer_hook.py b/tests/test_hook/test_iter_timer_hook.py
index 5e3b6e71..8b20d4f1 100644
--- a/tests/test_hook/test_iter_timer_hook.py
+++ b/tests/test_hook/test_iter_timer_hook.py
@@ -9,21 +9,21 @@ class TestIterTimerHook:
     def test_before_epoch(self):
         Hook = IterTimerHook()
         Runner = Mock()
-        Hook.before_epoch(Runner)
+        Hook._before_epoch(Runner)
         assert isinstance(Hook.t, float)
 
     def test_before_iter(self):
         Hook = IterTimerHook()
         Runner = Mock()
         Runner.log_buffer = dict()
-        Hook.before_epoch(Runner)
-        Hook.before_iter(Runner)
-        assert 'data_time' in Runner.log_buffer
+        Hook._before_epoch(Runner)
+        Hook._before_iter(Runner)
+        Runner.message_hub.update_log.assert_called()
 
     def test_after_iter(self):
         Hook = IterTimerHook()
         Runner = Mock()
         Runner.log_buffer = dict()
-        Hook.before_epoch(Runner)
-        Hook.after_iter(Runner)
-        assert 'time' in Runner.log_buffer
+        Hook._before_epoch(Runner)
+        Hook._after_iter(Runner)
+        Runner.message_hub.update_log.assert_called()
diff --git a/tests/test_hook/test_logger_hook.py b/tests/test_hook/test_logger_hook.py
index 90faad77..d92ac9bc 100644
--- a/tests/test_hook/test_logger_hook.py
+++ b/tests/test_hook/test_logger_hook.py
@@ -110,7 +110,7 @@ class TestLoggerHook:
         # Test end of the epoch.
         logger_hook = LoggerHook(by_epoch=True, ignore_last=False)
         logger_hook._log_train = MagicMock()
-        runner.data_loader = [0] * 5
+        runner.cur_dataloader = [0] * 5
         runner.inner_iter = 4
         logger_hook.after_train_iter(runner)
         logger_hook._log_train.assert_called()
@@ -155,7 +155,7 @@ class TestLoggerHook:
         out, _ = capsys.readouterr()
         time_avg = logger_hook.time_sec_tot / (
             runner.iter + 1 - logger_hook.start_iter)
-        eta_second = time_avg * (runner.max_iters - runner.iter - 1)
+        eta_second = time_avg * (runner.train_loop.max_iters - runner.iter - 1)
         eta_str = str(datetime.timedelta(seconds=int(eta_second)))
         if by_epoch:
             if torch.cuda.is_available():
@@ -337,10 +337,10 @@ class TestLoggerHook:
     def _setup_runner(self):
         runner = MagicMock()
         runner.epoch = 1
-        runner.data_loader = [0] * 5
+        runner.cur_dataloader = [0] * 5
         runner.inner_iter = 1
         runner.iter = 10
-        runner.max_iters = 50
+        runner.train_loop.max_iters = 50
         logger = logging.getLogger()
         logger.setLevel(logging.INFO)
         for handler in logger.handlers:
diff --git a/tests/test_hook/test_sampler_seed_hook.py b/tests/test_hook/test_sampler_seed_hook.py
index 0bf96743..9d19edf7 100644
--- a/tests/test_hook/test_sampler_seed_hook.py
+++ b/tests/test_hook/test_sampler_seed_hook.py
@@ -12,17 +12,17 @@ class TestDistSamplerSeedHook:
         # Test dataset sampler
         runner = Mock()
         runner.epoch = 1
-        runner.data_loader = Mock()
-        runner.data_loader.sampler = Mock()
-        runner.data_loader.sampler.set_epoch = Mock()
-        hook.before_epoch(runner)
-        runner.data_loader.sampler.set_epoch.assert_called()
+        runner.cur_dataloader = Mock()
+        runner.cur_dataloader.sampler = Mock()
+        runner.cur_dataloader.sampler.set_epoch = Mock()
+        hook.before_train_epoch(runner)
+        runner.cur_dataloader.sampler.set_epoch.assert_called()
         # Test batch sampler
         runner = Mock()
-        runner.data_loader = Mock()
-        runner.data_loader.sampler = Mock(spec_set=True)
-        runner.data_loader.batch_sampler = Mock()
-        runner.data_loader.batch_sampler.sampler = Mock()
-        runner.data_loader.batch_sampler.sampler.set_epoch = Mock()
-        hook.before_epoch(runner)
-        runner.data_loader.batch_sampler.sampler.set_epoch.assert_called()
+        runner.cur_dataloader = Mock()
+        runner.cur_dataloader.sampler = Mock(spec_set=True)
+        runner.cur_dataloader.batch_sampler = Mock()
+        runner.cur_dataloader.batch_sampler.sampler = Mock()
+        runner.cur_dataloader.batch_sampler.sampler.set_epoch = Mock()
+        hook.before_train_epoch(runner)
+        runner.cur_dataloader.batch_sampler.sampler.set_epoch.assert_called()
diff --git a/tests/test_hook/test_sync_buffers_hook.py b/tests/test_hook/test_sync_buffers_hook.py
index 6bba7de5..1c0b6295 100644
--- a/tests/test_hook/test_sync_buffers_hook.py
+++ b/tests/test_hook/test_sync_buffers_hook.py
@@ -10,4 +10,4 @@ class TestSyncBuffersHook:
         Runner = Mock()
         Runner.model = Mock()
         Hook = SyncBuffersHook()
-        Hook.after_epoch(Runner)
+        Hook._after_epoch(Runner)
-- 
GitLab