From ec3034b7650b79fb9d7f692d2b9384403adbcc8e Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Wed, 9 Mar 2022 23:10:19 +0800
Subject: [PATCH] [Fix] Fix output argument of after_iter, after_train_iter
 and after_val_iter (#115)

* Fix hook

* Fix

* Fix docs

* Fix

* Fix

* Fix as comment
---
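A minimal usage sketch for reviewers (not part of the patch): a hypothetical
subclass written against the signatures changed below. The class name
``LossReportHook`` and its print statements are illustrative only; the import
path ``mmengine.hooks.hook`` follows the file touched in this patch, and
treating the training ``outputs`` dict as losses is an assumption.

    # Hypothetical hook consuming the updated ``outputs`` types.
    from typing import Optional, Sequence

    from mmengine.data import BaseDataSample
    from mmengine.hooks.hook import Hook


    class LossReportHook(Hook):
        """Toy hook that inspects per-iteration model outputs."""

        def after_train_iter(self,
                             runner,
                             data_batch=None,
                             outputs: Optional[dict] = None) -> None:
            # During training, ``outputs`` is a dict (e.g. losses) or None.
            if outputs is not None:
                print('train output keys:', sorted(outputs.keys()))

        def after_val_iter(self,
                           runner,
                           data_batch=None,
                           outputs: Optional[Sequence[BaseDataSample]] = None
                           ) -> None:
            # During validation, ``outputs`` is a sequence of BaseDataSample
            # predictions or None.
            if outputs is not None:
                print(f'validated {len(outputs)} samples this iteration')

Hooks that override only ``after_iter`` (such as ``EmptyCacheHook`` and
``IterTimerHook`` below) cover both cases, since its annotation becomes
``Optional[Union[dict, Sequence[BaseDataSample]]]``.
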
 mmengine/hooks/checkpoint_hook.py      | 11 ++-
 mmengine/hooks/empty_cache_hook.py     |  8 ++-
 mmengine/hooks/hook.py                 | 93 +++++++++++++++++++++-----
 mmengine/hooks/iter_timer_hook.py      | 10 +--
 mmengine/hooks/logger_hook.py          | 11 ++-
 mmengine/hooks/optimizer_hook.py       | 11 ++-
 mmengine/hooks/param_scheduler_hook.py | 11 ++-
 7 files changed, 106 insertions(+), 49 deletions(-)

diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py
index a312307f..d8c75d1f 100644
--- a/mmengine/hooks/checkpoint_hook.py
+++ b/mmengine/hooks/checkpoint_hook.py
@@ -168,18 +168,17 @@ class CheckpointHook(Hook):
                 else:
                     break
 
-    def after_train_iter(
-            self,
-            runner,
-            data_batch: DATA_BATCH = None,
-            outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
+    def after_train_iter(self,
+                         runner,
+                         data_batch: DATA_BATCH = None,
+                         outputs: Optional[dict] = None) -> None:
         """Save the checkpoint and synchronize buffers after each iteration.
 
         Args:
             runner (Runner): The runner of the training process.
             data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
                 from dataloader. Defaults to None.
-            outputs (Sequence[BaseDataSample], optional): Outputs from model.
+            outputs (dict, optional): Outputs from model.
                 Defaults to None.
         """
         if self.by_epoch:
diff --git a/mmengine/hooks/empty_cache_hook.py b/mmengine/hooks/empty_cache_hook.py
index 2f621131..a37de337 100644
--- a/mmengine/hooks/empty_cache_hook.py
+++ b/mmengine/hooks/empty_cache_hook.py
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Any, Optional, Sequence, Tuple
+from typing import Any, Optional, Sequence, Tuple, Union
 
 import torch
 
@@ -37,14 +37,16 @@ class EmptyCacheHook(Hook):
     def after_iter(self,
                    runner,
                    data_batch: DATA_BATCH = None,
-                   outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
+                   outputs:
+                   Optional[Union[dict, Sequence[BaseDataSample]]] = None)\
+            -> None:
         """Empty cache after an iteration.
 
         Args:
             runner (Runner): The runner of the training process.
             data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
                 from dataloader. Defaults to None.
-            outputs (Sequence[BaseDataSample]): Outputs from model.
+            outputs (dict or sequence, optional): Outputs from model.
                 Defaults to None.
         """
         if self._after_iter:
diff --git a/mmengine/hooks/hook.py b/mmengine/hooks/hook.py
index f2d52fc9..91729375 100644
--- a/mmengine/hooks/hook.py
+++ b/mmengine/hooks/hook.py
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Any, Optional, Sequence, Tuple
+from typing import Any, Optional, Sequence, Tuple, Union
 
 from mmengine.data import BaseDataSample
 
@@ -19,7 +19,8 @@ class Hook:
         operations before the training process.
 
         Args:
-            runner (Runner): The runner of the training process.
+            runner (Runner): The runner of the training/validation/testing
+                process.
         """
         pass
 
@@ -27,11 +28,66 @@ class Hook:
         """All subclasses should override this method, if they need any
         operations after the training process.
 
+        Args:
+            runner (Runner): The runner of the training/validation/testing
+                process.
+        """
+        pass
+
+    def before_train(self, runner) -> None:
+        """All subclasses should override this method, if they need any
+        operations before train.
+
+        Args:
+            runner (Runner): The runner of the training process.
+        """
+        pass
+
+    def after_train(self, runner) -> None:
+        """All subclasses should override this method, if they need any
+        operations after train.
+
         Args:
             runner (Runner): The runner of the training process.
         """
         pass
 
+    def before_val(self, runner) -> None:
+        """All subclasses should override this method, if they need any
+        operations before val.
+
+        Args:
+            runner (Runner): The runner of the validation process.
+        """
+        pass
+
+    def after_val(self, runner) -> None:
+        """All subclasses should override this method, if they need any
+        operations after val.
+
+        Args:
+            runner (Runner): The runner of the validation process.
+        """
+        pass
+
+    def before_test(self, runner) -> None:
+        """All subclasses should override this method, if they need any
+        operations before test.
+
+        Args:
+            runner (Runner): The runner of the testing process.
+        """
+        pass
+
+    def after_test(self, runner) -> None:
+        """All subclasses should override this method, if they need any
+        operations after test.
+
+        Args:
+            runner (Runner): The runner of the testing process.
+        """
+        pass
+
     def before_epoch(self, runner) -> None:
         """All subclasses should override this method, if they need any
         operations before each epoch.
@@ -64,7 +120,9 @@ class Hook:
     def after_iter(self,
                    runner,
                    data_batch: DATA_BATCH = None,
-                   outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
+                   outputs:
+                   Optional[Union[dict, Sequence[BaseDataSample]]] = None) \
+            -> None:
         """All subclasses should override this method, if they need any
         operations after each epoch.
 
@@ -72,8 +130,8 @@ class Hook:
             runner (Runner): The runner of the training process.
             data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
                 Data from dataloader. Defaults to None.
-            outputs (Sequence[BaseDataSample], optional): Outputs from model.
-                Defaults to None.
+            outputs (dict or sequence, optional): Outputs from model. Defaults
+                to None.
         """
         pass
 
@@ -184,11 +242,10 @@ class Hook:
         """
         self.before_iter(runner, data_batch=None)
 
-    def after_train_iter(
-            self,
-            runner,
-            data_batch: DATA_BATCH = None,
-            outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
+    def after_train_iter(self,
+                         runner,
+                         data_batch: DATA_BATCH = None,
+                         outputs: Optional[dict] = None) -> None:
         """All subclasses should override this method, if they need any
         operations after each training iteration.
 
@@ -196,16 +253,16 @@ class Hook:
             runner (Runner): The runner of the training process.
             data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
                 Data from dataloader. Defaults to None.
-            outputs (Sequence[BaseDataSample], optional): Outputs from model.
+            outputs (dict, optional): Outputs from model.
                 Defaults to None.
         """
         self.after_iter(runner, data_batch=None, outputs=None)
 
-    def after_val_iter(
-            self,
-            runner,
-            data_batch: DATA_BATCH = None,
-            outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
+    def after_val_iter(self,
+                       runner,
+                       data_batch: DATA_BATCH = None,
+                       outputs: Optional[Sequence[BaseDataSample]] = None) \
+            -> None:
         """All subclasses should override this method, if they need any
         operations after each validation iteration.
 
@@ -213,7 +270,7 @@ class Hook:
             runner (Runner): The runner of the training process.
             data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
                 Data from dataloader. Defaults to None.
-            outputs (Sequence[BaseDataSample], optional): Outputs from
+            outputs (dict or sequence, optional): Outputs from
                 model. Defaults to None.
         """
         self.after_iter(runner, data_batch=None, outputs=None)
@@ -230,7 +287,7 @@ class Hook:
             runner (Runner): The runner of the training process.
             data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
                 Data from dataloader. Defaults to None.
-            outputs (Sequence[BaseDataSample], optional): Outputs from model.
+            outputs (dict, optional): Outputs from model.
                 Defaults to None.
         """
         self.after_iter(runner, data_batch=None, outputs=None)
diff --git a/mmengine/hooks/iter_timer_hook.py b/mmengine/hooks/iter_timer_hook.py
index d1d6404f..72701504 100644
--- a/mmengine/hooks/iter_timer_hook.py
+++ b/mmengine/hooks/iter_timer_hook.py
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import time
-from typing import Any, Optional, Sequence, Tuple
+from typing import Any, Optional, Sequence, Tuple, Union
 
 from mmengine.data import BaseDataSample
 from mmengine.registry import HOOKS
@@ -40,15 +40,17 @@ class IterTimerHook(Hook):
     def after_iter(self,
                    runner,
                    data_batch: DATA_BATCH = None,
-                   outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
+                   outputs:
+                   Optional[Union[dict, Sequence[BaseDataSample]]] = None) \
+            -> None:
         """Logging time for a iteration and update the time flag.
 
         Args:
             runner (Runner): The runner of the training process.
             data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
                 from dataloader. Defaults to None.
-            outputs (Sequence[BaseDataSample]): Outputs from model.
-                Defaults to None.
+            outputs (dict or sequence, optional): Outputs from model. Defaults
+                to None.
         """
         # TODO: update for new logging system
         runner.log_buffer.update({'time': time.time() - self.t})
diff --git a/mmengine/hooks/logger_hook.py b/mmengine/hooks/logger_hook.py
index 741bfd95..b557c2d0 100644
--- a/mmengine/hooks/logger_hook.py
+++ b/mmengine/hooks/logger_hook.py
@@ -171,18 +171,17 @@ class LoggerHook(Hook):
         if runner.meta is not None:
             runner.writer.add_params(runner.meta, file_path=self.yaml_log_path)
 
-    def after_train_iter(
-            self,
-            runner,
-            data_batch: DATA_BATCH = None,
-            outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
+    def after_train_iter(self,
+                         runner,
+                         data_batch: DATA_BATCH = None,
+                         outputs: Optional[dict] = None) -> None:
         """Record training logs.
 
         Args:
             runner (Runner): The runner of the training process.
             data_batch (Sequence[BaseDataSample], optional): Data from
                 dataloader. Defaults to None.
-            outputs (Sequence[BaseDataSample], optional): Outputs from model.
+            outputs (dict, optional): Outputs from model.
                 Defaults to None.
         """
         if runner.meta is not None and 'exp_name' in runner.meta:
diff --git a/mmengine/hooks/optimizer_hook.py b/mmengine/hooks/optimizer_hook.py
index 3a9dabbc..418bf06d 100644
--- a/mmengine/hooks/optimizer_hook.py
+++ b/mmengine/hooks/optimizer_hook.py
@@ -56,11 +56,10 @@ class OptimizerHook(Hook):
             return clip_grad.clip_grad_norm_(params, **self.grad_clip)
         return None
 
-    def after_train_iter(
-            self,
-            runner,
-            data_batch: DATA_BATCH = None,
-            outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
+    def after_train_iter(self,
+                         runner,
+                         data_batch: DATA_BATCH = None,
+                         outputs: Optional[dict] = None) -> None:
         """All operations need to be finished after each training iteration.
 
         This function will finish following 3 operations:
@@ -80,7 +79,7 @@ class OptimizerHook(Hook):
                 from dataloader. In order to keep this interface consistent
                 with other hooks, we keep ``data_batch`` here.
                 Defaults to None.
-            outputs (Sequence[BaseDataSample], optional): Outputs from model.
+            outputs (dict, optional): Outputs from model.
                 In order to keep this interface consistent with other hooks,
                 we keep ``outputs`` here. Defaults to None.
         """
diff --git a/mmengine/hooks/param_scheduler_hook.py b/mmengine/hooks/param_scheduler_hook.py
index 095b1bb3..a85ef3ac 100644
--- a/mmengine/hooks/param_scheduler_hook.py
+++ b/mmengine/hooks/param_scheduler_hook.py
@@ -15,11 +15,10 @@ class ParamSchedulerHook(Hook):
 
     priority = 'LOW'
 
-    def after_train_iter(
-            self,
-            runner,
-            data_batch: DATA_BATCH = None,
-            outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
+    def after_train_iter(self,
+                         runner,
+                         data_batch: DATA_BATCH = None,
+                         outputs: Optional[dict] = None) -> None:
         """Call step function for each scheduler after each iteration.
 
         Args:
@@ -28,7 +27,7 @@ class ParamSchedulerHook(Hook):
                 from dataloader. In order to keep this interface consistent
                 with other hooks, we keep ``data_batch`` here.
                 Defaults to None.
-            outputs (Sequence[BaseDataSample], optional): Outputs from model.
+            outputs (dict, optional): Outputs from model.
                 In order to keep this interface consistent with other hooks, we
                 keep ``data_batch`` here. Defaults to None.
         """
-- 
GitLab