Skip to content
Snippets Groups Projects
Unverified Commit 15abb061 authored by Yuan Liu's avatar Yuan Liu Committed by GitHub
Browse files

[Fix]: Fix data batch type in base hook (#99)


* [Fix]: Fix data batch type in base hook

* [Fix]: Fix the type hint bug in checkpoint, optimizer, param scheduler hooks

Co-authored-by: default avatarYour <you@example.com>
parent 3adf4ea6
No related branches found
No related tags found
No related merge requests found
......@@ -2,7 +2,7 @@
import os.path as osp
import warnings
from pathlib import Path
from typing import Optional, Sequence, Union
from typing import Any, Optional, Sequence, Tuple, Union
from mmengine.data import BaseDataSample
from mmengine.fileio import FileClient
......@@ -179,14 +179,14 @@ class CheckpointHook(Hook):
def after_train_iter(
self,
runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None,
data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""Save the checkpoint and synchronize buffers after each iteration.
Args:
runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample]): Data from dataloader.
Defaults to None.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from model.
Defaults to None.
"""
......
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence
from typing import Any, Optional, Sequence, Tuple
from mmengine.data import BaseDataSample
......@@ -49,31 +49,33 @@ class Hook:
pass
def before_iter(
self,
runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None) -> None:
self,
runner: object,
data_batch: Optional[Sequence[Tuple[Any,
BaseDataSample]]] = None) -> None:
"""All subclasses should override this method, if they need any
operations before each iter.
Args:
runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample]): Data from dataloader.
Defaults to None.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
"""
pass
def after_iter(self,
runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None,
data_batch: Optional[Sequence[Tuple[
Any, BaseDataSample]]] = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""All subclasses should override this method, if they need any
operations after each epoch.
Args:
runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample]): Data from dataloader.
Defaults to None.
outputs (Sequence[BaseDataSample]): Outputs from model.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from model.
Defaults to None.
"""
pass
......@@ -153,59 +155,62 @@ class Hook:
self.after_epoch(runner)
def before_train_iter(
self,
runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None) -> None:
self,
runner: object,
data_batch: Optional[Sequence[Tuple[Any,
BaseDataSample]]] = None) -> None:
"""All subclasses should override this method, if they need any
operations before each training iteration.
Args:
runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample], optional): Data from
dataloader. Defaults to None.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
"""
self.before_iter(runner, data_batch=None)
def before_val_iter(
self,
runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None) -> None:
self,
runner: object,
data_batch: Optional[Sequence[Tuple[Any,
BaseDataSample]]] = None) -> None:
"""All subclasses should override this method, if they need any
operations before each validation iteration.
Args:
runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample], optional): Data from
dataloader. Defaults to None.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
"""
self.before_iter(runner, data_batch=None)
def before_test_iter(
self,
runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None) -> None:
self,
runner: object,
data_batch: Optional[Sequence[Tuple[Any,
BaseDataSample]]] = None) -> None:
"""All subclasses should override this method, if they need any
operations before each test iteration.
Args:
runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample], optional): Data from
dataloader. Defaults to None.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
"""
self.before_iter(runner, data_batch=None)
def after_train_iter(
self,
runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None,
data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""All subclasses should override this method, if they need any
operations after each training iteration.
Args:
runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample], optional): Data from
dataloader. Defaults to None.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from model.
Defaults to None.
"""
......@@ -214,15 +219,15 @@ class Hook:
def after_val_iter(
self,
runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None,
data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""All subclasses should override this method, if they need any
operations after each validation iteration.
Args:
runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample], optional): Data from
dataloader. Defaults to None.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from
model. Defaults to None.
"""
......@@ -231,15 +236,15 @@ class Hook:
def after_test_iter(
self,
runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None,
data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""All subclasses should override this method, if they need any
operations after each test iteration.
Args:
runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample], optional): Data from
dataloader. Defaults to None.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from model.
Defaults to None.
"""
......
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import List, Optional, Sequence
from typing import Any, List, Optional, Sequence, Tuple
import torch
from torch.nn.parameter import Parameter
......@@ -57,7 +57,7 @@ class OptimizerHook(Hook):
def after_train_iter(
self,
runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None,
data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""All operations need to be finished after each training iteration.
......@@ -74,9 +74,10 @@ class OptimizerHook(Hook):
Args:
runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample], optional): Data from
dataloader. In order to keep this interface consistent with
other hooks, we keep ``data_batch`` here. Defaults to None.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
from dataloader. In order to keep this interface consistent
with other hooks, we keep ``data_batch`` here.
Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from model.
In order to keep this interface consistent with other hooks,
we keep ``outputs`` here. Defaults to None.
......
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence
from typing import Any, Optional, Sequence, Tuple
from mmengine.data import BaseDataSample
from mmengine.registry import HOOKS
......@@ -15,17 +15,19 @@ class ParamSchedulerHook(Hook):
def after_iter(self,
runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None,
data_batch: Optional[Sequence[Tuple[
Any, BaseDataSample]]] = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""Call step function for each scheduler after each iteration.
Args:
runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample]): Data from dataloader. In
order to keep this interface consistent with other hooks, we
keep ``data_batch`` here. Defaults to None.
outputs (Sequence[BaseDataSample]): Outputs from model. In
order to keep this interface consistent with other hooks, we
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
from dataloader. In order to keep this interface consistent
with other hooks, we keep ``data_batch`` here.
Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from model.
In order to keep this interface consistent with other hooks, we
keep ``data_batch`` here. Defaults to None.
"""
for scheduler in runner.schedulers: # type: ignore
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment