Unverified Commit ed8dcb4c authored by Zaida Zhou, committed by GitHub

fix type hint in hooks (#106)

parent 9f0d1a96
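The diff below applies one pattern across every hook module: the repeated inline annotation Optional[Sequence[Tuple[Any, BaseDataSample]]] is hoisted into a module-level DATA_BATCH alias, the over-strict `runner: object` annotation is dropped in favour of an unannotated `runner`, and the now-unneeded `# type: ignore` comments are removed. A minimal, self-contained sketch of that pattern (not part of the commit; BaseDataSample is stubbed so the snippet runs standalone):

    from typing import Any, Optional, Sequence, Tuple

    class BaseDataSample:
        """Stand-in for mmengine.data.BaseDataSample (stubbed for illustration)."""

    # Module-level alias introduced by the commit in each hook module.
    DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]

    class Hook:
        """Base hook with the updated signatures."""

        def before_train_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
            # `runner` is deliberately left unannotated: annotating it as `object`
            # forced `# type: ignore` on every attribute access such as
            # `runner.epoch` or `runner.logger`.
            pass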
@@ -9,6 +9,8 @@ from mmengine.fileio import FileClient
from mmengine.registry import HOOKS
from .hook import Hook
DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]
@HOOKS.register_module()
class CheckpointHook(Hook):
@@ -65,7 +67,7 @@ class CheckpointHook(Hook):
self.sync_buffer = sync_buffer
self.file_client_args = file_client_args
def before_run(self, runner: object) -> None:
def before_run(self, runner) -> None:
"""Finish all operations, related to checkpoint.
This function will get the appropriate file client, and the directory
@@ -75,7 +77,7 @@ class CheckpointHook(Hook):
runner (Runner): The runner of the training process.
"""
if not self.out_dir:
self.out_dir = runner.work_dir # type: ignore
self.out_dir = runner.work_dir
self.file_client = FileClient.infer_client(self.file_client_args,
self.out_dir)
@@ -84,17 +86,13 @@ class CheckpointHook(Hook):
# `self.out_dir` is set so the final `self.out_dir` is the
# concatenation of `self.out_dir` and the last level directory of
# `runner.work_dir`
if self.out_dir != runner.work_dir: # type: ignore
basename = osp.basename(
runner.work_dir.rstrip( # type: ignore
osp.sep))
if self.out_dir != runner.work_dir:
basename = osp.basename(runner.work_dir.rstrip(osp.sep))
self.out_dir = self.file_client.join_path(
self.out_dir, # type: ignore
basename)
self.out_dir, basename) # type: ignore # noqa: E501
runner.logger.info(( # type: ignore
f'Checkpoints will be saved to {self.out_dir} by '
f'{self.file_client.name}.'))
runner.logger.info((f'Checkpoints will be saved to {self.out_dir} by '
f'{self.file_client.name}.'))
# disable the create_symlink option because some file backends do not
# allow to create a symlink
@@ -109,7 +107,7 @@ class CheckpointHook(Hook):
else:
self.args['create_symlink'] = self.file_client.allow_symlink
def after_train_epoch(self, runner: object) -> None:
def after_train_epoch(self, runner) -> None:
"""Save the checkpoint and synchronize buffers after each epoch.
Args:
@@ -124,46 +122,40 @@ class CheckpointHook(Hook):
if self.every_n_epochs(
runner, self.interval) or (self.save_last
and self.is_last_epoch(runner)):
runner.logger.info( # type: ignore
f'Saving checkpoint at \
{runner.epoch + 1} epochs') # type: ignore
runner.logger.info(f'Saving checkpoint at \
{runner.epoch + 1} epochs')
if self.sync_buffer:
pass
# TODO
self._save_checkpoint(runner)
# TODO Add master_only decorator
def _save_checkpoint(self, runner: object) -> None:
def _save_checkpoint(self, runner) -> None:
"""Save the current checkpoint and delete outdated checkpoint.
Args:
runner (Runner): The runner of the training process.
"""
runner.save_checkpoint( # type: ignore
self.out_dir,
save_optimizer=self.save_optimizer,
**self.args)
if runner.meta is not None: # type: ignore
runner.save_checkpoint(
self.out_dir, save_optimizer=self.save_optimizer, **self.args)
if runner.meta is not None:
if self.by_epoch:
cur_ckpt_filename = self.args.get(
'filename_tmpl',
'epoch_{}.pth').format(runner.epoch + 1) # type: ignore
'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1)
else:
cur_ckpt_filename = self.args.get(
'filename_tmpl',
'iter_{}.pth').format(runner.iter + 1) # type: ignore
runner.meta.setdefault('hook_msgs', dict()) # type: ignore
runner.meta['hook_msgs'][ # type: ignore
'last_ckpt'] = self.file_client.join_path(
self.out_dir, cur_ckpt_filename) # type: ignore
'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
runner.meta.setdefault('hook_msgs', dict())
runner.meta['hook_msgs']['last_ckpt'] = self.file_client.join_path(
self.out_dir, cur_ckpt_filename) # type: ignore
# remove other checkpoints
if self.max_keep_ckpts > 0:
if self.by_epoch:
name = 'epoch_{}.pth'
current_ckpt = runner.epoch + 1 # type: ignore
current_ckpt = runner.epoch + 1
else:
name = 'iter_{}.pth'
current_ckpt = runner.iter + 1 # type: ignore
current_ckpt = runner.iter + 1
redundant_ckpts = range(
current_ckpt - self.max_keep_ckpts * self.interval, 0,
-self.interval)
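As a quick illustration of the pruning range computed just above (values invented for the example, not taken from the diff):

    # Illustrative numbers only: by_epoch=True, interval=2, max_keep_ckpts=3,
    # saving at the end of the 10th epoch (runner.epoch + 1 == 10).
    interval, max_keep_ckpts = 2, 3
    current_ckpt = 10
    redundant_ckpts = range(current_ckpt - max_keep_ckpts * interval, 0, -interval)
    print(list(redundant_ckpts))  # [4, 2] -> epoch_4.pth and epoch_2.pth are removed,
                                  # keeping epoch_10.pth, epoch_8.pth and epoch_6.pth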
@@ -178,8 +170,8 @@ class CheckpointHook(Hook):
def after_train_iter(
self,
runner: object,
data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
runner,
data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""Save the checkpoint and synchronize buffers after each iteration.
@@ -199,9 +191,8 @@ class CheckpointHook(Hook):
if self.every_n_iters(
runner, self.interval) or (self.save_last
and self.is_last_iter(runner)):
runner.logger.info( # type: ignore
f'Saving checkpoint at \
{runner.iter + 1} iterations') # type: ignore
runner.logger.info(f'Saving checkpoint at \
{runner.iter + 1} iterations')
if self.sync_buffer:
pass
# TODO
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence
from typing import Any, Optional, Sequence, Tuple
import torch
@@ -7,6 +7,8 @@ from mmengine.data import BaseDataSample
from mmengine.registry import HOOKS
from .hook import Hook
DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]
@HOOKS.register_module()
class EmptyCacheHook(Hook):
@@ -33,35 +35,35 @@ class EmptyCacheHook(Hook):
self._after_iter = after_iter
def after_iter(self,
runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None,
runner,
data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""Empty cache after an iteration.
Args:
runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample]): Data from dataloader.
Defaults to None.
runner (Runner): The runner of the training process.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample]): Outputs from model.
Defaults to None.
"""
if self._after_iter:
torch.cuda.empty_cache()
def before_epoch(self, runner: object) -> None:
def before_epoch(self, runner) -> None:
"""Empty cache before an epoch.
Args:
runner (object): The runner of the training process.
runner (Runner): The runner of the training process.
"""
if self._before_epoch:
torch.cuda.empty_cache()
def after_epoch(self, runner: object) -> None:
def after_epoch(self, runner) -> None:
"""Empty cache after an epoch.
Args:
runner (object): The runner of the training process.
runner (Runner): The runner of the training process.
"""
if self._after_epoch:
torch.cuda.empty_cache()
@@ -3,6 +3,8 @@ from typing import Any, Optional, Sequence, Tuple
from mmengine.data import BaseDataSample
DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]
class Hook:
"""Base hook class.
@@ -12,7 +14,7 @@ class Hook:
priority = 'NORMAL'
def before_run(self, runner: object) -> None:
def before_run(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations before the training process.
@@ -21,7 +23,7 @@ class Hook:
"""
pass
def after_run(self, runner: object) -> None:
def after_run(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations after the training process.
@@ -30,7 +32,7 @@ class Hook:
"""
pass
def before_epoch(self, runner: object) -> None:
def before_epoch(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations before each epoch.
@@ -39,7 +41,7 @@ class Hook:
"""
pass
def after_epoch(self, runner: object) -> None:
def after_epoch(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations after each epoch.
@@ -48,11 +50,7 @@ class Hook:
"""
pass
def before_iter(
self,
runner: object,
data_batch: Optional[Sequence[Tuple[Any,
BaseDataSample]]] = None) -> None:
def before_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
"""All subclasses should override this method, if they need any
operations before each iter.
@@ -64,9 +62,8 @@ class Hook:
pass
def after_iter(self,
runner: object,
data_batch: Optional[Sequence[Tuple[
Any, BaseDataSample]]] = None,
runner,
data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""All subclasses should override this method, if they need any
operations after each epoch.
@@ -80,7 +77,7 @@ class Hook:
"""
pass
def before_save_checkpoint(self, runner: object, checkpoint: dict) -> None:
def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
"""All subclasses should override this method, if they need any
operations before saving the checkpoint.
@@ -90,7 +87,7 @@ class Hook:
"""
pass
def after_load_checkpoint(self, runner: object, checkpoint: dict) -> None:
def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
"""All subclasses should override this method, if they need any
operations after loading the checkpoint.
@@ -100,7 +97,7 @@ class Hook:
"""
pass
def before_train_epoch(self, runner: object) -> None:
def before_train_epoch(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations before each training epoch.
@@ -109,7 +106,7 @@ class Hook:
"""
self.before_epoch(runner)
def before_val_epoch(self, runner: object) -> None:
def before_val_epoch(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations before each validation epoch.
@@ -118,7 +115,7 @@ class Hook:
"""
self.before_epoch(runner)
def before_test_epoch(self, runner: object) -> None:
def before_test_epoch(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations before each test epoch.
@@ -127,7 +124,7 @@ class Hook:
"""
self.before_epoch(runner)
def after_train_epoch(self, runner: object) -> None:
def after_train_epoch(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations after each training epoch.
@@ -136,7 +133,7 @@ class Hook:
"""
self.after_epoch(runner)
def after_val_epoch(self, runner: object) -> None:
def after_val_epoch(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations after each validation epoch.
@@ -145,7 +142,7 @@ class Hook:
"""
self.after_epoch(runner)
def after_test_epoch(self, runner: object) -> None:
def after_test_epoch(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations after each test epoch.
@@ -154,11 +151,7 @@ class Hook:
"""
self.after_epoch(runner)
def before_train_iter(
self,
runner: object,
data_batch: Optional[Sequence[Tuple[Any,
BaseDataSample]]] = None) -> None:
def before_train_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
"""All subclasses should override this method, if they need any
operations before each training iteration.
@@ -169,11 +162,7 @@ class Hook:
"""
self.before_iter(runner, data_batch=None)
def before_val_iter(
self,
runner: object,
data_batch: Optional[Sequence[Tuple[Any,
BaseDataSample]]] = None) -> None:
def before_val_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
"""All subclasses should override this method, if they need any
operations before each validation iteration.
@@ -184,11 +173,7 @@ class Hook:
"""
self.before_iter(runner, data_batch=None)
def before_test_iter(
self,
runner: object,
data_batch: Optional[Sequence[Tuple[Any,
BaseDataSample]]] = None) -> None:
def before_test_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
"""All subclasses should override this method, if they need any
operations before each test iteration.
@@ -201,8 +186,8 @@ class Hook:
def after_train_iter(
self,
runner: object,
data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
runner,
data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""All subclasses should override this method, if they need any
operations after each training iteration.
@@ -218,8 +203,8 @@ class Hook:
def after_val_iter(
self,
runner: object,
data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
runner,
data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""All subclasses should override this method, if they need any
operations after each validation iteration.
@@ -235,8 +220,8 @@ class Hook:
def after_test_iter(
self,
runner: object,
data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
runner,
data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""All subclasses should override this method, if they need any
operations after each test iteration.
@@ -250,7 +235,7 @@ class Hook:
"""
self.after_iter(runner, data_batch=None, outputs=None)
def every_n_epochs(self, runner: object, n: int) -> bool:
def every_n_epochs(self, runner, n: int) -> bool:
"""Test whether or not current epoch can be evenly divided by n.
Args:
@@ -260,9 +245,9 @@ class Hook:
Returns:
bool: whether or not current epoch can be evenly divided by n.
"""
return (runner.epoch + 1) % n == 0 if n > 0 else False # type: ignore
return (runner.epoch + 1) % n == 0 if n > 0 else False
def every_n_inner_iters(self, runner: object, n: int) -> bool:
def every_n_inner_iters(self, runner, n: int) -> bool:
"""Test whether or not current inner iteration can be evenly divided by
n.
@@ -275,10 +260,9 @@ class Hook:
bool: whether or not current inner iteration can be evenly
divided by n.
"""
return (runner.inner_iter + # type: ignore
1) % n == 0 if n > 0 else False
return (runner.inner_iter + 1) % n == 0 if n > 0 else False
def every_n_iters(self, runner: object, n: int) -> bool:
def every_n_iters(self, runner, n: int) -> bool:
"""Test whether or not current iteration can be evenly divided by n.
Args:
@@ -290,9 +274,9 @@ class Hook:
bool: Return True if the current iteration can be evenly divided
by n, otherwise False.
"""
return (runner.iter + 1) % n == 0 if n > 0 else False # type: ignore
return (runner.iter + 1) % n == 0 if n > 0 else False
def end_of_epoch(self, runner: object) -> bool:
def end_of_epoch(self, runner) -> bool:
"""Check whether the current epoch reaches the `max_epochs` or not.
Args:
@@ -301,9 +285,9 @@ class Hook:
Returns:
bool: whether the end of current epoch or not.
"""
return runner.inner_iter + 1 == len(runner.data_loader) # type: ignore
return runner.inner_iter + 1 == len(runner.data_loader)
def is_last_epoch(self, runner: object) -> bool:
def is_last_epoch(self, runner) -> bool:
"""Test whether or not current epoch is the last epoch.
Args:
@@ -313,9 +297,9 @@ class Hook:
bool: bool: Return True if the current epoch reaches the
`max_epochs`, otherwise False.
"""
return runner.epoch + 1 == runner._max_epochs # type: ignore
return runner.epoch + 1 == runner._max_epochs
def is_last_iter(self, runner: object) -> bool:
def is_last_iter(self, runner) -> bool:
"""Test whether or not current epoch is the last iteration.
Args:
@@ -324,4 +308,4 @@ class Hook:
Returns:
bool: whether or not current iteration is the last iteration.
"""
return runner.iter + 1 == runner._max_iters # type: ignore
return runner.iter + 1 == runner._max_iters
# Copyright (c) OpenMMLab. All rights reserved.
import time
from typing import Optional, Sequence
from typing import Any, Optional, Sequence, Tuple
from mmengine.data import BaseDataSample
from mmengine.registry import HOOKS
from .hook import Hook
DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]
@HOOKS.register_module()
class IterTimerHook(Hook):
@@ -16,45 +18,38 @@ class IterTimerHook(Hook):
priority = 'NORMAL'
def before_epoch(self, runner: object) -> None:
def before_epoch(self, runner) -> None:
"""Record time flag before start a epoch.
Args:
runner (object): The runner of the training process.
runner (Runner): The runner of the training process.
"""
self.t = time.time()
def before_iter(
self,
runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None) -> None:
def before_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
"""Logging time for loading data and update the time flag.
Args:
runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample]): Data from dataloader.
Defaults to None.
runner (Runner): The runner of the training process.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
from dataloader. Defaults to None.
"""
# TODO: update for new logging system
runner.log_buffer.update({ # type: ignore
'data_time': time.time() - self.t
})
runner.log_buffer.update({'data_time': time.time() - self.t})
def after_iter(self,
runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None,
runner,
data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""Logging time for a iteration and update the time flag.
Args:
runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample]): Data from dataloader.
Defaults to None.
runner (Runner): The runner of the training process.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample]): Outputs from model.
Defaults to None.
"""
# TODO: update for new logging system
runner.log_buffer.update({ # type: ignore
'time': time.time() - self.t
})
runner.log_buffer.update({'time': time.time() - self.t})
self.t = time.time()
@@ -10,6 +10,8 @@ from mmengine.data import BaseDataSample
from mmengine.registry import HOOKS
from .hook import Hook
DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]
@HOOKS.register_module()
class OptimizerHook(Hook):
@@ -56,8 +58,8 @@ class OptimizerHook(Hook):
def after_train_iter(
self,
runner: object,
data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
runner,
data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""All operations need to be finished after each training iteration.
@@ -82,32 +84,27 @@ class OptimizerHook(Hook):
In order to keep this interface consistent with other hooks,
we keep ``outputs`` here. Defaults to None.
"""
runner.optimizer.zero_grad() # type: ignore
runner.optimizer.zero_grad()
if self.detect_anomalous_params:
self.detect_anomalous_parameters(
runner.outputs['loss'], # type: ignore
runner)
runner.outputs['loss'].backward() # type: ignore
self.detect_anomalous_parameters(runner.outputs['loss'], runner)
runner.outputs['loss'].backward()
if self.grad_clip is not None:
grad_norm = self.clip_grads(
runner.model.parameters()) # type: ignore
grad_norm = self.clip_grads(runner.model.parameters())
if grad_norm is not None:
# Add grad norm to the logger
runner.log_buffer.update( # type: ignore
{'grad_norm': float(grad_norm)},
runner.outputs['num_samples']) # type: ignore
runner.optimizer.step() # type: ignore
runner.log_buffer.update({'grad_norm': float(grad_norm)},
runner.outputs['num_samples'])
runner.optimizer.step()
def detect_anomalous_parameters(self, loss: torch.Tensor,
runner: object) -> None:
def detect_anomalous_parameters(self, loss: torch.Tensor, runner) -> None:
"""Detect anomalous parameters that are not included in the graph.
Args:
loss (torch.Tensor): The loss of current iteration.
runner (Runner): The runner of the training process.
"""
logger = runner.logger # type: ignore
logger = runner.logger
parameters_in_graph = set()
visited = set()
@@ -125,7 +122,7 @@ class OptimizerHook(Hook):
traverse(grad_fn)
traverse(loss.grad_fn)
for n, p in runner.model.named_parameters(): # type: ignore
for n, p in runner.model.named_parameters():
if p not in parameters_in_graph and p.requires_grad:
logger.log(
level=logging.ERROR,
@@ -5,6 +5,8 @@ from mmengine.data import BaseDataSample
from mmengine.registry import HOOKS
from .hook import Hook
DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]
@HOOKS.register_module()
class ParamSchedulerHook(Hook):
@@ -15,8 +17,8 @@ class ParamSchedulerHook(Hook):
def after_train_iter(
self,
runner: object,
data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
runner,
data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""Call step function for each scheduler after each iteration.
@@ -30,16 +32,16 @@ class ParamSchedulerHook(Hook):
In order to keep this interface consistent with other hooks, we
keep ``data_batch`` here. Defaults to None.
"""
for scheduler in runner.schedulers: # type: ignore
for scheduler in runner.schedulers:
if not scheduler.by_epoch:
scheduler.step()
def after_train_epoch(self, runner: object) -> None:
def after_train_epoch(self, runner) -> None:
"""Call step function for each scheduler after each epoch.
Args:
runner (Runner): The runner of the training process.
"""
for scheduler in runner.schedulers: # type: ignore
for scheduler in runner.schedulers:
if scheduler.by_epoch:
scheduler.step()
@@ -14,18 +14,15 @@ class DistSamplerSeedHook(Hook):
priority = 'NORMAL'
def before_epoch(self, runner: object) -> None:
def before_epoch(self, runner) -> None:
"""Set the seed for sampler and batch_sampler.
Args:
runner (Runner): The runner of the training process.
"""
if hasattr(runner.data_loader.sampler, 'set_epoch'): # type: ignore
if hasattr(runner.data_loader.sampler, 'set_epoch'):
# in case the data loader uses `SequentialSampler` in Pytorch
runner.data_loader.sampler.set_epoch(runner.epoch) # type: ignore
elif hasattr(
runner.data_loader.batch_sampler.sampler, # type: ignore
'set_epoch'):
runner.data_loader.sampler.set_epoch(runner.epoch)
elif hasattr(runner.data_loader.batch_sampler.sampler, 'set_epoch'):
# batch sampler in pytorch warps the sampler as its attributes.
runner.data_loader.batch_sampler.sampler.set_epoch( # type: ignore
runner.epoch) # type: ignore
runner.data_loader.batch_sampler.sampler.set_epoch(runner.epoch)
@@ -89,11 +89,11 @@ class SyncBuffersHook(Hook):
def __init__(self) -> None:
self.distributed = dist.IS_DIST
def after_epoch(self, runner: object) -> None:
def after_epoch(self, runner) -> None:
"""All-reduce model buffers at the end of each epoch.
Args:
runner (object): The runner of the training process.
runner (Runner): The runner of the training process.
"""
if self.distributed:
allreduce_params(runner.model.buffers()) # type: ignore
allreduce_params(runner.model.buffers())
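Taken together, the hunks above leave every hook entry point with an unannotated `runner` and a shared DATA_BATCH alias. A hypothetical downstream hook written against these signatures might look as follows (illustrative only: the class name, its logic, and the exact import paths are assumptions based on the imports visible in the diff, not part of the commit):

    from typing import Any, Optional, Sequence, Tuple

    from mmengine.data import BaseDataSample
    from mmengine.hooks import Hook
    from mmengine.registry import HOOKS

    # Same alias shape the commit introduces in each hook module.
    DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]

    @HOOKS.register_module()
    class LogEveryNIterHook(Hook):
        """Toy hook that logs a message every ``n`` training iterations."""

        def __init__(self, n: int = 50) -> None:
            self.n = n

        def after_train_iter(self,
                             runner,
                             data_batch: DATA_BATCH = None,
                             outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
            # No `# type: ignore` is needed on runner attribute access anymore.
            if self.every_n_iters(runner, self.n):
                runner.logger.info(f'Reached iteration {runner.iter + 1}')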