Skip to content
Snippets Groups Projects
Unverified Commit 53101a1a authored by Zaida Zhou's avatar Zaida Zhou Committed by GitHub
Browse files

[Docs] Refine hook documentation (#181)

* Modify hook documentation

* resolve comments
parent ecf816e1
No related branches found
No related tags found
No related merge requests found
...@@ -135,39 +135,39 @@ def main(): ...@@ -135,39 +135,39 @@ def main():
def main(): def main():
... ...
call_hooks('before_run', hooks) # 训练开始前执行的逻辑 call_hooks('before_run', hooks) # 训练开始前执行的逻辑
call_hooks('after_load_checkpoint') # 加载权重后执行的逻辑 call_hooks('after_load_checkpoint', hooks) # 加载权重后执行的逻辑
for i in range(max_epochs): for i in range(max_epochs):
call_hooks('before_train_epoch') # 遍历训练数据集前执行的逻辑 call_hooks('before_train_epoch', hooks) # 遍历训练数据集前执行的逻辑
for inputs, labels in train_dataloader: for inputs, labels in train_dataloader:
call_hooks('before_train_iter') # 模型前向计算前执行的逻辑 call_hooks('before_train_iter', hooks) # 模型前向计算前执行的逻辑
outputs = net(inputs) outputs = net(inputs)
loss = criterion(outputs, labels) loss = criterion(outputs, labels)
call_hooks('after_train_iter') # 模型前向计算后执行的逻辑 call_hooks('after_train_iter', hooks) # 模型前向计算后执行的逻辑
loss.backward() loss.backward()
optimizer.step() optimizer.step()
call_hooks('after_train_epoch') # 遍历完训练数据集后执行的逻辑 call_hooks('after_train_epoch', hooks) # 遍历完训练数据集后执行的逻辑
call_hooks('before_val_epoch') # 遍历验证数据集前执行的逻辑 call_hooks('before_val_epoch', hooks) # 遍历验证数据集前执行的逻辑
with torch.no_grad(): with torch.no_grad():
for inputs, labels in val_dataloader: for inputs, labels in val_dataloader:
call_hooks('before_val_iter') # 模型前向计算前执行 call_hooks('before_val_iter', hooks) # 模型前向计算前执行
outputs = net(inputs) outputs = net(inputs)
loss = criterion(outputs, labels) loss = criterion(outputs, labels)
call_hooks('after_val_iter') # 模型前向计算后执行 call_hooks('after_val_iter', hooks) # 模型前向计算后执行
call_hooks('after_val_epoch') # 遍历完验证数据集后执行 call_hooks('after_val_epoch', hooks) # 遍历完验证数据集后执行
call_hooks('before_save_checkpoint') # 保存权重前执行的逻辑 call_hooks('before_save_checkpoint', hooks) # 保存权重前执行的逻辑
call_hooks('before_test_epoch') # 遍历测试数据集前执行的逻辑 call_hooks('before_test_epoch', hooks) # 遍历测试数据集前执行的逻辑
with torch.no_grad(): with torch.no_grad():
for inputs, labels in test_dataloader: for inputs, labels in test_dataloader:
call_hooks('before_test_iter') # 模型前向计算前执行的逻辑 call_hooks('before_test_iter', hooks) # 模型前向计算前执行的逻辑
outputs = net(inputs) outputs = net(inputs)
accuracy = ... accuracy = ...
call_hooks('after_test_iter') # 模型前向计算后执行的逻辑 call_hooks('after_test_iter', hooks) # 模型前向计算后执行的逻辑
call_hooks('after_test_epoch') # 遍历完测试数据集后执行 call_hooks('after_test_epoch', hooks) # 遍历完测试数据集后执行
call_hooks('after_run') # 训练结束后执行的逻辑 call_hooks('after_run', hooks) # 训练结束后执行的逻辑
``` ```
在 MMEngine 中,我们将训练过程抽象成执行器(Runner),执行器除了完成环境的初始化,另一个功能是在特定的位点调用钩子完成定制化逻辑。更多关于执行器的介绍请阅读[文档](https://mmengine.readthedocs.io/zh_CN/latest/tutorials/runner.html) 在 MMEngine 中,我们将训练过程抽象成执行器(Runner),执行器除了完成环境的初始化,另一个功能是在特定的位点调用钩子完成定制化逻辑。更多关于执行器的介绍请阅读[文档](https://mmengine.readthedocs.io/zh_CN/latest/tutorials/runner.html)
...@@ -221,6 +221,7 @@ from mmengine import Runner ...@@ -221,6 +221,7 @@ from mmengine import Runner
default_hooks = dict( default_hooks = dict(
optimizer=dict(type='OptimizerHook'), optimizer=dict(type='OptimizerHook'),
timer=dict(type='IterTimerHook'), timer=dict(type='IterTimerHook'),
sampler_seed=dict(type='DistSamplerSeedHook'),
logger=dict(type='TextLoggerHook'), logger=dict(type='TextLoggerHook'),
param_scheduler=dict(type='ParamSchedulerHook'), param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(type='CheckpointHook', interval=1) checkpoint=dict(type='CheckpointHook', interval=1)
...@@ -384,7 +385,7 @@ config = dict(type='SyncBuffersHook') ...@@ -384,7 +385,7 @@ config = dict(type='SyncBuffersHook')
如果 MMEngine 提供的默认钩子不能满足需求,用户可以自定义钩子,只需继承钩子基类并重写相应的位点方法。 如果 MMEngine 提供的默认钩子不能满足需求,用户可以自定义钩子,只需继承钩子基类并重写相应的位点方法。
例如,如果希望在训练的过程中判断损失值是否有效,如果值为无穷大则无效,我们可以在每次迭代后判断损失值是否无穷大,因此只需重写 `after_train_iter` 位点。 例如,如果希望在训练的过程中判断损失值是否有效,如果值为无穷大则无效,我们可以在每次迭代后判断损失值是否无穷大,因此只需重写 `after_train_iter` 位点。
```python ```python
import torch import torch
...@@ -407,7 +408,18 @@ class CheckInvalidLossHook(Hook): ...@@ -407,7 +408,18 @@ class CheckInvalidLossHook(Hook):
def __init__(self, interval=50): def __init__(self, interval=50):
self.interval = interval self.interval = interval
def after_train_iter(self, runner, data_batch): def after_train_iter(self, runner, batch_idx, data_batch=None, outputs=None):
"""All subclasses should override this method, if they need any
operations after each training iteration.
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
data_batch (Sequence[dict], optional): Data from dataloader.
Defaults to None.
outputs (dict, optional): Outputs from model.
Defaults to None.
"""
if self.every_n_iters(runner, self.interval): if self.every_n_iters(runner, self.interval):
assert torch.isfinite(runner.outputs['loss']), \ assert torch.isfinite(runner.outputs['loss']), \
runner.logger.info('loss become infinite or NaN!') runner.logger.info('loss become infinite or NaN!')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment