From 15abb061ef148da66751c9293ee9e6a89eec0c59 Mon Sep 17 00:00:00 2001 From: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com> Date: Mon, 7 Mar 2022 13:25:45 +0800 Subject: [PATCH] [Fix]: Fix data batch type in base hook (#99) * [Fix]: Fix data batch type in base hook * [Fix]: Fix the type hint bug in checkpoint, optimizer, param scheduler hooks Co-authored-by: Your <you@example.com> --- mmengine/hooks/checkpoint_hook.py | 8 +-- mmengine/hooks/hook.py | 73 ++++++++++++++------------ mmengine/hooks/optimizer_hook.py | 11 ++-- mmengine/hooks/param_scheduler_hook.py | 16 +++--- 4 files changed, 58 insertions(+), 50 deletions(-) diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py index 14a7ab7b..89c742f3 100644 --- a/mmengine/hooks/checkpoint_hook.py +++ b/mmengine/hooks/checkpoint_hook.py @@ -2,7 +2,7 @@ import os.path as osp import warnings from pathlib import Path -from typing import Optional, Sequence, Union +from typing import Any, Optional, Sequence, Tuple, Union from mmengine.data import BaseDataSample from mmengine.fileio import FileClient @@ -179,14 +179,14 @@ class CheckpointHook(Hook): def after_train_iter( self, runner: object, - data_batch: Optional[Sequence[BaseDataSample]] = None, + data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None, outputs: Optional[Sequence[BaseDataSample]] = None) -> None: """Save the checkpoint and synchronize buffers after each iteration. Args: runner (object): The runner of the training process. - data_batch (Sequence[BaseDataSample]): Data from dataloader. - Defaults to None. + data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data + from dataloader. Defaults to None. outputs (Sequence[BaseDataSample], optional): Outputs from model. Defaults to None. """ diff --git a/mmengine/hooks/hook.py b/mmengine/hooks/hook.py index f0ccb1f7..2582bc7b 100644 --- a/mmengine/hooks/hook.py +++ b/mmengine/hooks/hook.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Optional, Sequence +from typing import Any, Optional, Sequence, Tuple from mmengine.data import BaseDataSample @@ -49,31 +49,33 @@ class Hook: pass def before_iter( - self, - runner: object, - data_batch: Optional[Sequence[BaseDataSample]] = None) -> None: + self, + runner: object, + data_batch: Optional[Sequence[Tuple[Any, + BaseDataSample]]] = None) -> None: """All subclasses should override this method, if they need any operations before each iter. Args: runner (object): The runner of the training process. - data_batch (Sequence[BaseDataSample]): Data from dataloader. - Defaults to None. + data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): + Data from dataloader. Defaults to None. """ pass def after_iter(self, runner: object, - data_batch: Optional[Sequence[BaseDataSample]] = None, + data_batch: Optional[Sequence[Tuple[ + Any, BaseDataSample]]] = None, outputs: Optional[Sequence[BaseDataSample]] = None) -> None: """All subclasses should override this method, if they need any operations after each epoch. Args: runner (object): The runner of the training process. - data_batch (Sequence[BaseDataSample]): Data from dataloader. - Defaults to None. - outputs (Sequence[BaseDataSample]): Outputs from model. + data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): + Data from dataloader. Defaults to None. + outputs (Sequence[BaseDataSample], optional): Outputs from model. Defaults to None. """ pass @@ -153,59 +155,62 @@ class Hook: self.after_epoch(runner) def before_train_iter( - self, - runner: object, - data_batch: Optional[Sequence[BaseDataSample]] = None) -> None: + self, + runner: object, + data_batch: Optional[Sequence[Tuple[Any, + BaseDataSample]]] = None) -> None: """All subclasses should override this method, if they need any operations before each training iteration. Args: runner (object): The runner of the training process. - data_batch (Sequence[BaseDataSample], optional): Data from - dataloader. Defaults to None. + data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): + Data from dataloader. Defaults to None. """ self.before_iter(runner, data_batch=None) def before_val_iter( - self, - runner: object, - data_batch: Optional[Sequence[BaseDataSample]] = None) -> None: + self, + runner: object, + data_batch: Optional[Sequence[Tuple[Any, + BaseDataSample]]] = None) -> None: """All subclasses should override this method, if they need any operations before each validation iteration. Args: runner (object): The runner of the training process. - data_batch (Sequence[BaseDataSample], optional): Data from - dataloader. Defaults to None. + data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): + Data from dataloader. Defaults to None. """ self.before_iter(runner, data_batch=None) def before_test_iter( - self, - runner: object, - data_batch: Optional[Sequence[BaseDataSample]] = None) -> None: + self, + runner: object, + data_batch: Optional[Sequence[Tuple[Any, + BaseDataSample]]] = None) -> None: """All subclasses should override this method, if they need any operations before each test iteration. Args: runner (object): The runner of the training process. - data_batch (Sequence[BaseDataSample], optional): Data from - dataloader. Defaults to None. + data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): + Data from dataloader. Defaults to None. """ self.before_iter(runner, data_batch=None) def after_train_iter( self, runner: object, - data_batch: Optional[Sequence[BaseDataSample]] = None, + data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None, outputs: Optional[Sequence[BaseDataSample]] = None) -> None: """All subclasses should override this method, if they need any operations after each training iteration. Args: runner (object): The runner of the training process. - data_batch (Sequence[BaseDataSample], optional): Data from - dataloader. Defaults to None. + data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): + Data from dataloader. Defaults to None. outputs (Sequence[BaseDataSample], optional): Outputs from model. Defaults to None. """ @@ -214,15 +219,15 @@ class Hook: def after_val_iter( self, runner: object, - data_batch: Optional[Sequence[BaseDataSample]] = None, + data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None, outputs: Optional[Sequence[BaseDataSample]] = None) -> None: """All subclasses should override this method, if they need any operations after each validation iteration. Args: runner (object): The runner of the training process. - data_batch (Sequence[BaseDataSample], optional): Data from - dataloader. Defaults to None. + data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): + Data from dataloader. Defaults to None. outputs (Sequence[BaseDataSample], optional): Outputs from model. Defaults to None. """ @@ -231,15 +236,15 @@ class Hook: def after_test_iter( self, runner: object, - data_batch: Optional[Sequence[BaseDataSample]] = None, + data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None, outputs: Optional[Sequence[BaseDataSample]] = None) -> None: """All subclasses should override this method, if they need any operations after each test iteration. Args: runner (object): The runner of the training process. - data_batch (Sequence[BaseDataSample], optional): Data from - dataloader. Defaults to None. + data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): + Data from dataloader. Defaults to None. outputs (Sequence[BaseDataSample], optional): Outputs from model. Defaults to None. """ diff --git a/mmengine/hooks/optimizer_hook.py b/mmengine/hooks/optimizer_hook.py index 99f010ab..a22cb1bb 100644 --- a/mmengine/hooks/optimizer_hook.py +++ b/mmengine/hooks/optimizer_hook.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import logging -from typing import List, Optional, Sequence +from typing import Any, List, Optional, Sequence, Tuple import torch from torch.nn.parameter import Parameter @@ -57,7 +57,7 @@ class OptimizerHook(Hook): def after_train_iter( self, runner: object, - data_batch: Optional[Sequence[BaseDataSample]] = None, + data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None, outputs: Optional[Sequence[BaseDataSample]] = None) -> None: """All operations need to be finished after each training iteration. @@ -74,9 +74,10 @@ class OptimizerHook(Hook): Args: runner (object): The runner of the training process. - data_batch (Sequence[BaseDataSample], optional): Data from - dataloader. In order to keep this interface consistent with - other hooks, we keep ``data_batch`` here. Defaults to None. + data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data + from dataloader. In order to keep this interface consistent + with other hooks, we keep ``data_batch`` here. + Defaults to None. outputs (Sequence[BaseDataSample], optional): Outputs from model. In order to keep this interface consistent with other hooks, we keep ``outputs`` here. Defaults to None. diff --git a/mmengine/hooks/param_scheduler_hook.py b/mmengine/hooks/param_scheduler_hook.py index 425ab123..e99d9251 100644 --- a/mmengine/hooks/param_scheduler_hook.py +++ b/mmengine/hooks/param_scheduler_hook.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Optional, Sequence +from typing import Any, Optional, Sequence, Tuple from mmengine.data import BaseDataSample from mmengine.registry import HOOKS @@ -15,17 +15,19 @@ class ParamSchedulerHook(Hook): def after_iter(self, runner: object, - data_batch: Optional[Sequence[BaseDataSample]] = None, + data_batch: Optional[Sequence[Tuple[ + Any, BaseDataSample]]] = None, outputs: Optional[Sequence[BaseDataSample]] = None) -> None: """Call step function for each scheduler after each iteration. Args: runner (object): The runner of the training process. - data_batch (Sequence[BaseDataSample]): Data from dataloader. In - order to keep this interface consistent with other hooks, we - keep ``data_batch`` here. Defaults to None. - outputs (Sequence[BaseDataSample]): Outputs from model. In - order to keep this interface consistent with other hooks, we + data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data + from dataloader. In order to keep this interface consistent + with other hooks, we keep ``data_batch`` here. + Defaults to None. + outputs (Sequence[BaseDataSample], optional): Outputs from model. + In order to keep this interface consistent with other hooks, we keep ``data_batch`` here. Defaults to None. """ for scheduler in runner.schedulers: # type: ignore -- GitLab