Skip to content
Snippets Groups Projects
Unverified Commit ec3034b7 authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

[Fix] Fix output argument of after_iter, after_train_iter and after_val_iter (#115)

* Fix hook

* Fix

* Fix docs

* Fix

* Fix

* Fix as comment
parent 3bdd27c4
No related branches found
No related tags found
No related merge requests found
......@@ -168,18 +168,17 @@ class CheckpointHook(Hook):
else:
break
def after_train_iter(
self,
runner,
data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
def after_train_iter(self,
runner,
data_batch: DATA_BATCH = None,
outputs=Optional[dict]) -> None:
"""Save the checkpoint and synchronize buffers after each iteration.
Args:
runner (Runner): The runner of the training process.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from model.
outputs (dict, optional): Outputs from model.
Defaults to None.
"""
if self.by_epoch:
......
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Optional, Sequence, Tuple
from typing import Any, Optional, Sequence, Tuple, Union
import torch
......@@ -37,14 +37,16 @@ class EmptyCacheHook(Hook):
def after_iter(self,
runner,
data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
outputs:
Optional[Union[dict, Sequence[BaseDataSample]]] = None)\
-> None:
"""Empty cache after an iteration.
Args:
runner (Runner): The runner of the training process.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample]): Outputs from model.
outputs (dict or sequence, optional): Outputs from model.
Defaults to None.
"""
if self._after_iter:
......
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Optional, Sequence, Tuple
from typing import Any, Optional, Sequence, Tuple, Union
from mmengine.data import BaseDataSample
......@@ -19,7 +19,8 @@ class Hook:
operations before the training process.
Args:
runner (Runner): The runner of the training process.
runner (Runner): The runner of the training/validation/testing
process.
"""
pass
......@@ -27,11 +28,66 @@ class Hook:
"""All subclasses should override this method, if they need any
operations after the training process.
Args:
runner (Runner): The runner of the training/validation/testing
process.
"""
pass
def before_train(self, runner) -> None:
    """Hook point entered right before the training loop begins.

    Subclasses that need setup work ahead of training (e.g. state
    initialization) should override this method; the base class does
    nothing.

    Args:
        runner (Runner): The runner of the training process.
    """
    pass
def after_train(self, runner) -> None:
    """Hook point entered right after the training loop finishes.

    Subclasses that need teardown or summary work once training is
    done should override this method; the base class does nothing.

    Args:
        runner (Runner): The runner of the training process.
    """
    pass
def before_val(self, runner) -> None:
    """Hook point entered right before the validation loop begins.

    Override in a subclass to perform work ahead of validation; the
    base class is intentionally a no-op.

    Args:
        runner (Runner): The runner of the validation process.
    """
    pass
def after_val(self, runner) -> None:
    """Hook point entered right after the validation loop finishes.

    Override in a subclass to perform work once validation is done;
    the base class is intentionally a no-op.

    Args:
        runner (Runner): The runner of the validation process.
    """
    pass
def before_test(self, runner) -> None:
    """Hook point entered right before the test loop begins.

    Override in a subclass to perform work ahead of testing; the base
    class is intentionally a no-op.

    Args:
        runner (Runner): The runner of the testing process.
    """
    pass
def after_test(self, runner) -> None:
    """Hook point entered right after the test loop finishes.

    Override in a subclass to perform work once testing is done; the
    base class is intentionally a no-op.

    Args:
        runner (Runner): The runner of the testing process.
    """
    pass
def before_epoch(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations before each epoch.
......@@ -64,7 +120,9 @@ class Hook:
def after_iter(self,
runner,
data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
outputs:
Optional[Union[dict, Sequence[BaseDataSample]]] = None) \
-> None:
"""All subclasses should override this method, if they need any
operations after each epoch.
......@@ -72,8 +130,8 @@ class Hook:
runner (Runner): The runner of the training process.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from model.
Defaults to None.
outputs (dict or sequence, optional): Outputs from model. Defaults
to None.
"""
pass
......@@ -184,11 +242,10 @@ class Hook:
"""
self.before_iter(runner, data_batch=None)
def after_train_iter(
self,
runner,
data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
def after_train_iter(self,
runner,
data_batch: DATA_BATCH = None,
outputs: Optional[dict] = None) -> None:
"""All subclasses should override this method, if they need any
operations after each training iteration.
......@@ -196,16 +253,16 @@ class Hook:
runner (Runner): The runner of the training process.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from model.
outputs (dict, optional): Outputs from model.
Defaults to None.
"""
self.after_iter(runner, data_batch=None, outputs=None)
def after_val_iter(
self,
runner,
data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
def after_val_iter(self,
runner,
data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataSample]] = None) \
-> None:
"""All subclasses should override this method, if they need any
operations after each validation iteration.
......@@ -213,7 +270,7 @@ class Hook:
runner (Runner): The runner of the training process.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from
outputs (dict or sequence, optional): Outputs from
model. Defaults to None.
"""
self.after_iter(runner, data_batch=None, outputs=None)
......@@ -230,7 +287,7 @@ class Hook:
runner (Runner): The runner of the training process.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from model.
outputs (dict, optional): Outputs from model.
Defaults to None.
"""
self.after_iter(runner, data_batch=None, outputs=None)
......
# Copyright (c) OpenMMLab. All rights reserved.
import time
from typing import Any, Optional, Sequence, Tuple
from typing import Any, Optional, Sequence, Tuple, Union
from mmengine.data import BaseDataSample
from mmengine.registry import HOOKS
......@@ -40,15 +40,17 @@ class IterTimerHook(Hook):
def after_iter(self,
runner,
data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
outputs:
Optional[Union[dict, Sequence[BaseDataSample]]] = None) \
-> None:
"""Logging time for a iteration and update the time flag.
Args:
runner (Runner): The runner of the training process.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample]): Outputs from model.
Defaults to None.
outputs (dict or sequence, optional): Outputs from model. Defaults
to None.
"""
# TODO: update for new logging system
runner.log_buffer.update({'time': time.time() - self.t})
......
......@@ -171,18 +171,17 @@ class LoggerHook(Hook):
if runner.meta is not None:
runner.writer.add_params(runner.meta, file_path=self.yaml_log_path)
def after_train_iter(
self,
runner,
data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
def after_train_iter(self,
runner,
data_batch: DATA_BATCH = None,
outputs: Optional[dict] = None) -> None:
"""Record training logs.
Args:
runner (Runner): The runner of the training process.
data_batch (Sequence[BaseDataSample], optional): Data from
dataloader. Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from model.
outputs (dict, optional): Outputs from model.
Defaults to None.
"""
if runner.meta is not None and 'exp_name' in runner.meta:
......
......@@ -56,11 +56,10 @@ class OptimizerHook(Hook):
return clip_grad.clip_grad_norm_(params, **self.grad_clip)
return None
def after_train_iter(
self,
runner,
data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
def after_train_iter(self,
runner,
data_batch: DATA_BATCH = None,
outputs: Optional[dict] = None) -> None:
"""All operations need to be finished after each training iteration.
This function will finish following 3 operations:
......@@ -80,7 +79,7 @@ class OptimizerHook(Hook):
from dataloader. In order to keep this interface consistent
with other hooks, we keep ``data_batch`` here.
Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from model.
outputs (dict, optional): Outputs from model.
In order to keep this interface consistent with other hooks,
we keep ``outputs`` here. Defaults to None.
"""
......
......@@ -15,11 +15,10 @@ class ParamSchedulerHook(Hook):
priority = 'LOW'
def after_train_iter(
self,
runner,
data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
def after_train_iter(self,
runner,
data_batch: DATA_BATCH = None,
outputs: Optional[dict] = None) -> None:
"""Call step function for each scheduler after each iteration.
Args:
......@@ -28,7 +27,7 @@ class ParamSchedulerHook(Hook):
from dataloader. In order to keep this interface consistent
with other hooks, we keep ``data_batch`` here.
Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from model.
outputs (dict, optional): Outputs from model.
In order to keep this interface consistent with other hooks, we
keep ``data_batch`` here. Defaults to None.
"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment