From 15abb061ef148da66751c9293ee9e6a89eec0c59 Mon Sep 17 00:00:00 2001
From: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com>
Date: Mon, 7 Mar 2022 13:25:45 +0800
Subject: [PATCH] [Fix]: Fix data batch type in base hook (#99)

* [Fix]: Fix data batch type in base hook

* [Fix]: Fix the type hint bug in checkpoint, optimizer, param scheduler hooks

Co-authored-by: Your <you@example.com>
---
 mmengine/hooks/checkpoint_hook.py      |  8 +--
 mmengine/hooks/hook.py                 | 73 ++++++++++++++------------
 mmengine/hooks/optimizer_hook.py       | 11 ++--
 mmengine/hooks/param_scheduler_hook.py | 16 +++---
 4 files changed, 58 insertions(+), 50 deletions(-)

diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py
index 14a7ab7b..89c742f3 100644
--- a/mmengine/hooks/checkpoint_hook.py
+++ b/mmengine/hooks/checkpoint_hook.py
@@ -2,7 +2,7 @@
 import os.path as osp
 import warnings
 from pathlib import Path
-from typing import Optional, Sequence, Union
+from typing import Any, Optional, Sequence, Tuple, Union
 
 from mmengine.data import BaseDataSample
 from mmengine.fileio import FileClient
@@ -179,14 +179,14 @@ class CheckpointHook(Hook):
     def after_train_iter(
             self,
             runner: object,
-            data_batch: Optional[Sequence[BaseDataSample]] = None,
+            data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
             outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
         """Save the checkpoint and synchronize buffers after each iteration.
 
         Args:
             runner (object): The runner of the training process.
-            data_batch (Sequence[BaseDataSample]): Data from dataloader.
-                Defaults to None.
+            data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
+                from dataloader. Defaults to None.
             outputs (Sequence[BaseDataSample], optional): Outputs from model.
                 Defaults to None.
         """
diff --git a/mmengine/hooks/hook.py b/mmengine/hooks/hook.py
index f0ccb1f7..2582bc7b 100644
--- a/mmengine/hooks/hook.py
+++ b/mmengine/hooks/hook.py
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Optional, Sequence
+from typing import Any, Optional, Sequence, Tuple
 
 from mmengine.data import BaseDataSample
 
@@ -49,31 +49,33 @@ class Hook:
         pass
 
     def before_iter(
-            self,
-            runner: object,
-            data_batch: Optional[Sequence[BaseDataSample]] = None) -> None:
+        self,
+        runner: object,
+        data_batch: Optional[Sequence[Tuple[Any,
+                                            BaseDataSample]]] = None) -> None:
         """All subclasses should override this method, if they need any
         operations before each iter.
 
         Args:
             runner (object): The runner of the training process.
-            data_batch (Sequence[BaseDataSample]): Data from dataloader.
-                Defaults to None.
+            data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
+                Data from dataloader. Defaults to None.
         """
         pass
 
     def after_iter(self,
                    runner: object,
-                   data_batch: Optional[Sequence[BaseDataSample]] = None,
+                   data_batch: Optional[Sequence[Tuple[
+                       Any, BaseDataSample]]] = None,
                    outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
         """All subclasses should override this method, if they need any
-        operations after each epoch.
+        operations after each iter.
 
         Args:
             runner (object): The runner of the training process.
-            data_batch (Sequence[BaseDataSample]): Data from dataloader.
-                Defaults to None.
-            outputs (Sequence[BaseDataSample]): Outputs from model.
+            data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
+                Data from dataloader. Defaults to None.
+            outputs (Sequence[BaseDataSample], optional): Outputs from model.
                 Defaults to None.
         """
         pass
@@ -153,59 +155,62 @@ class Hook:
         self.after_epoch(runner)
 
     def before_train_iter(
-            self,
-            runner: object,
-            data_batch: Optional[Sequence[BaseDataSample]] = None) -> None:
+        self,
+        runner: object,
+        data_batch: Optional[Sequence[Tuple[Any,
+                                            BaseDataSample]]] = None) -> None:
         """All subclasses should override this method, if they need any
         operations before each training iteration.
 
         Args:
             runner (object): The runner of the training process.
-            data_batch (Sequence[BaseDataSample], optional): Data from
-                dataloader. Defaults to None.
+            data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
+                Data from dataloader. Defaults to None.
         """
         self.before_iter(runner, data_batch=None)
 
     def before_val_iter(
-            self,
-            runner: object,
-            data_batch: Optional[Sequence[BaseDataSample]] = None) -> None:
+        self,
+        runner: object,
+        data_batch: Optional[Sequence[Tuple[Any,
+                                            BaseDataSample]]] = None) -> None:
         """All subclasses should override this method, if they need any
         operations before each validation iteration.
 
         Args:
             runner (object): The runner of the training process.
-            data_batch (Sequence[BaseDataSample], optional): Data from
-                dataloader. Defaults to None.
+            data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
+                Data from dataloader. Defaults to None.
         """
         self.before_iter(runner, data_batch=None)
 
     def before_test_iter(
-            self,
-            runner: object,
-            data_batch: Optional[Sequence[BaseDataSample]] = None) -> None:
+        self,
+        runner: object,
+        data_batch: Optional[Sequence[Tuple[Any,
+                                            BaseDataSample]]] = None) -> None:
         """All subclasses should override this method, if they need any
         operations before each test iteration.
 
         Args:
             runner (object): The runner of the training process.
-            data_batch (Sequence[BaseDataSample], optional): Data from
-                dataloader. Defaults to None.
+            data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
+                Data from dataloader. Defaults to None.
         """
         self.before_iter(runner, data_batch=None)
 
     def after_train_iter(
             self,
             runner: object,
-            data_batch: Optional[Sequence[BaseDataSample]] = None,
+            data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
             outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
         """All subclasses should override this method, if they need any
         operations after each training iteration.
 
         Args:
             runner (object): The runner of the training process.
-            data_batch (Sequence[BaseDataSample], optional): Data from
-                dataloader. Defaults to None.
+            data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
+                Data from dataloader. Defaults to None.
             outputs (Sequence[BaseDataSample], optional): Outputs from model.
                 Defaults to None.
         """
@@ -214,15 +219,15 @@ class Hook:
     def after_val_iter(
             self,
             runner: object,
-            data_batch: Optional[Sequence[BaseDataSample]] = None,
+            data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
             outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
         """All subclasses should override this method, if they need any
         operations after each validation iteration.
 
         Args:
             runner (object): The runner of the training process.
-            data_batch (Sequence[BaseDataSample], optional): Data from
-                dataloader. Defaults to None.
+            data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
+                Data from dataloader. Defaults to None.
             outputs (Sequence[BaseDataSample], optional): Outputs from
                 model. Defaults to None.
         """
@@ -231,15 +236,15 @@ class Hook:
     def after_test_iter(
             self,
             runner: object,
-            data_batch: Optional[Sequence[BaseDataSample]] = None,
+            data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
             outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
         """All subclasses should override this method, if they need any
         operations after each test iteration.
 
         Args:
             runner (object): The runner of the training process.
-            data_batch (Sequence[BaseDataSample], optional): Data from
-                dataloader. Defaults to None.
+            data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
+                Data from dataloader. Defaults to None.
             outputs (Sequence[BaseDataSample], optional): Outputs from model.
                 Defaults to None.
         """
diff --git a/mmengine/hooks/optimizer_hook.py b/mmengine/hooks/optimizer_hook.py
index 99f010ab..a22cb1bb 100644
--- a/mmengine/hooks/optimizer_hook.py
+++ b/mmengine/hooks/optimizer_hook.py
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import logging
-from typing import List, Optional, Sequence
+from typing import Any, List, Optional, Sequence, Tuple
 
 import torch
 from torch.nn.parameter import Parameter
@@ -57,7 +57,7 @@ class OptimizerHook(Hook):
     def after_train_iter(
             self,
             runner: object,
-            data_batch: Optional[Sequence[BaseDataSample]] = None,
+            data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
             outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
         """All operations need to be finished after each training iteration.
 
@@ -74,9 +74,10 @@ class OptimizerHook(Hook):
 
         Args:
             runner (object): The runner of the training process.
-            data_batch (Sequence[BaseDataSample], optional): Data from
-                dataloader. In order to keep this interface consistent with
-                other hooks, we keep ``data_batch`` here. Defaults to None.
+            data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
+                from dataloader. In order to keep this interface consistent
+                with other hooks, we keep ``data_batch`` here.
+                Defaults to None.
             outputs (Sequence[BaseDataSample], optional): Outputs from model.
                 In order to keep this interface consistent with other hooks,
                 we keep ``outputs`` here. Defaults to None.
diff --git a/mmengine/hooks/param_scheduler_hook.py b/mmengine/hooks/param_scheduler_hook.py
index 425ab123..e99d9251 100644
--- a/mmengine/hooks/param_scheduler_hook.py
+++ b/mmengine/hooks/param_scheduler_hook.py
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Optional, Sequence
+from typing import Any, Optional, Sequence, Tuple
 
 from mmengine.data import BaseDataSample
 from mmengine.registry import HOOKS
@@ -15,17 +15,19 @@ class ParamSchedulerHook(Hook):
 
     def after_iter(self,
                    runner: object,
-                   data_batch: Optional[Sequence[BaseDataSample]] = None,
+                   data_batch: Optional[Sequence[Tuple[
+                       Any, BaseDataSample]]] = None,
                    outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
         """Call step function for each scheduler after each iteration.
 
         Args:
             runner (object): The runner of the training process.
-            data_batch (Sequence[BaseDataSample]): Data from dataloader. In
-                order to keep this interface consistent with other hooks, we
-                keep ``data_batch`` here. Defaults to None.
-            outputs (Sequence[BaseDataSample]): Outputs from model. In
-                order to keep this interface consistent with other hooks, we
+            data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
+                from dataloader. In order to keep this interface consistent
+                with other hooks, we keep ``data_batch`` here.
+                Defaults to None.
+            outputs (Sequence[BaseDataSample], optional): Outputs from model.
+                In order to keep this interface consistent with other hooks, we
-                keep ``data_batch`` here. Defaults to None.
+                keep ``outputs`` here. Defaults to None.
         """
         for scheduler in runner.schedulers:  # type: ignore
-- 
GitLab