Commit 6073d9eb authored by Mashiro, committed by GitHub
[Enhance] add documents for `clip_grad`, and support clip grad by value. (#513)

* [Enhance] add documents for `clip_grad`, and support clip grad by value

* refine docstring

* fix as comment

* Fix as comment

* minor refine

* minor refine

* remove error comment for clip grad

* refine docstring
parent 4111cfb5
@@ -132,6 +132,18 @@ for idx, (input, target) in enumerate(zip(inputs, targets)):
optim_wrapper.zero_grad()
```
We can also configure a gradient clipping strategy for the optimizer wrapper:
```python
# Clip gradients based on torch.nn.utils.clip_grad_norm_
optim_wrapper = AmpOptimWrapper(
optimizer=optimizer, clip_grad=dict(max_norm=1))
# Clip gradients based on torch.nn.utils.clip_grad_value_
optim_wrapper = AmpOptimWrapper(
optimizer=optimizer, clip_grad=dict(clip_value=0.2))
```
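The clipping itself runs inside the optimizer wrapper, so the training loop needs no extra call. Below is a minimal sketch under the assumptions of the loop shown earlier (a `model`, `inputs`, `targets`, and `optimizer` already exist, and the forward call returning a loss is hypothetical); `clip_grad` is accepted by both `OptimWrapper` and `AmpOptimWrapper`, and `type='norm'` is the default when `type` is omitted:
```python
# Gradient clipping is applied automatically before each optimizer step.
optim_wrapper = OptimWrapper(
    optimizer=optimizer, clip_grad=dict(type='norm', max_norm=1))

for idx, (input, target) in enumerate(zip(inputs, targets)):
    loss = model(input)                # hypothetical forward pass returning a loss
    optim_wrapper.update_params(loss)  # backward + clip gradients + step + zero_grad
```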
### Get the learning rate/momentum
The optimizer wrapper provides the `get_lr` and `get_momentum` interfaces for getting the learning rate and momentum of one of the optimizer's parameter groups, for example:
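A brief sketch (the printed values are illustrative; each interface returns a dict keyed by the hyperparameter name with one entry per parameter group, assuming the usual `mmengine.optim` import path):
```python
import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import OptimWrapper

model = nn.Linear(1, 1)
optim_wrapper = OptimWrapper(SGD(model.parameters(), lr=0.01, momentum=0.9))
print(optim_wrapper.get_lr())        # e.g. {'lr': [0.01]}
print(optim_wrapper.get_momentum())  # e.g. {'momentum': [0.9]}
```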
......
@@ -5,7 +5,6 @@ from typing import Dict, List, Optional
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad
from torch.optim import Optimizer
from mmengine.logging import MessageHub, print_log
@@ -32,7 +31,27 @@ class OptimWrapper:
gradients. The parameters will be updated per
``accumulative_counts``.
clip_grad (dict, optional): If ``clip_grad`` is not None, it will be
the arguments of ``torch.nn.utils.clip_grad``.
the arguments of :func:`torch.nn.utils.clip_grad_norm_` or
:func:`torch.nn.utils.clip_grad_value_`. ``clip_grad`` should be a
dict, and its keys can be set as follows:
If the key ``type`` is not set, or ``type`` is "norm",
the accepted keys are as follows:
- max_norm (float or int): Max norm of the gradients.
- norm_type (float or int): Type of the used p-norm. Can be
``'inf'`` for infinity norm.
- error_if_nonfinite (bool): If True, an error is thrown if
the total norm of the gradients from :attr:`parameters` is
``nan``, ``inf``, or ``-inf``. Default: False (will switch
to True in the future)
If the key ``type`` is set to "value", the accepted keys are as
follows:
- clip_value (float or int): maximum allowed value of the
gradients. The gradients are clipped in the range
``(-clip_value, +clip_value)``.
Note:
If ``accumulative_counts`` is larger than 1, perform
@@ -49,11 +68,18 @@ class OptimWrapper:
``_inner_count += 1`` is automatically performed.
Examples:
>>> # Config sample of OptimWrapper.
>>> # Config sample of OptimWrapper that enables gradient clipping by
>>> # norm.
>>> optim_wrapper_cfg = dict(
>>> type='OptimWrapper',
>>> accumulative_counts=1,
>>> clip_grad=dict(max_norm=0.2))
>>> # Config sample of OptimWrapper that enables gradient clipping by
>>> # value.
>>> optim_wrapper_cfg = dict(
>>> type='OptimWrapper',
>>> accumulative_counts=1,
>>> clip_grad=dict(type='value', clip_value=0.2))
>>> # Use OptimWrapper to update model.
>>> import torch.nn as nn
>>> import torch
@@ -105,7 +131,22 @@ class OptimWrapper:
# `clip_grad` should be a non-empty dict.
assert isinstance(clip_grad, dict) and clip_grad, (
'If `clip_grad` is not None, it should be a `dict` '
'which is the arguments of `torch.nn.utils.clip_grad`')
'which is the arguments of `torch.nn.utils.clip_grad_norm_` '
'or `clip_grad_value_`.')
clip_type = clip_grad.pop('type', 'norm')
if clip_type == 'norm':
self.clip_func = torch.nn.utils.clip_grad_norm_
self.grad_name = 'grad_norm'
elif clip_type == 'value':
self.clip_func = torch.nn.utils.clip_grad_value_
self.grad_name = 'grad_value'
else:
raise ValueError('type of clip_grad should be "norm" or '
f'"value" but got {clip_type}')
assert clip_grad, ('`clip_grad` should contain other arguments '
'besides `type`. The arguments should match those '
'of `torch.nn.utils.clip_grad_norm_` or '
'`clip_grad_value_`.')
self.clip_grad_kwargs = clip_grad
# Used to update `grad_norm` log message.
self.message_hub = MessageHub.get_current_instance()
@@ -305,9 +346,11 @@ class OptimWrapper:
params = list(
filter(lambda p: p.requires_grad and p.grad is not None, params))
if len(params) > 0:
grad_norm = clip_grad.clip_grad_norm_(params,
**self.clip_grad_kwargs)
self.message_hub.update_scalar('train/grad_norm', float(grad_norm))
grad = self.clip_func(params, **self.clip_grad_kwargs)
# `torch.nn.utils.clip_grad_value_` will return None.
if grad is not None:
self.message_hub.update_scalar(f'train/{self.grad_name}',
float(grad))
def initialize_count_status(self, model: nn.Module, init_counts: int,
max_counts: int) -> None:
......
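To see the dispatch introduced above in isolation, here is a minimal self-contained sketch (independent of MMEngine; `build_clip_func` is a hypothetical helper that mirrors the branch on `type`, while `clip_grad_norm_` and `clip_grad_value_` are the real PyTorch utilities):
```python
import torch
import torch.nn as nn


def build_clip_func(clip_grad: dict):
    """Hypothetical helper mirroring the dispatch above: pop ``type`` and
    return the matching clip function, its log key, and the remaining kwargs."""
    clip_grad = dict(clip_grad)  # avoid mutating the caller's config
    clip_type = clip_grad.pop('type', 'norm')
    if clip_type == 'norm':
        return torch.nn.utils.clip_grad_norm_, 'grad_norm', clip_grad
    if clip_type == 'value':
        return torch.nn.utils.clip_grad_value_, 'grad_value', clip_grad
    raise ValueError(
        f'type of clip_grad should be "norm" or "value", but got {clip_type}')


model = nn.Linear(4, 2)
model(torch.randn(1, 4)).sum().backward()
clip_func, grad_name, kwargs = build_clip_func(dict(type='value', clip_value=0.2))
ret = clip_func(model.parameters(), **kwargs)
# clip_grad_norm_ returns the total norm as a tensor; clip_grad_value_ returns
# None, which is why only the "norm" branch produces a scalar to log.
print(grad_name, ret)
```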
@@ -191,6 +191,7 @@ class TestOptimWrapper(MultiProcessTestCase):
# in the future).
@pytest.mark.skipif(True, reason='Solved in the future')
def test_clip_grads(self):
# Test `clip_grad` with `clip_grad_norm_`
optim_wrapper = OptimWrapper(
self.optimizer, clip_grad=dict(max_norm=35))
loss = self.model(torch.Tensor(1, 1, 1, 1))
@@ -198,6 +199,15 @@ class TestOptimWrapper(MultiProcessTestCase):
optim_wrapper._clip_grad()
log_scalars = self.message_hub.log_scalars
self.assertIn('train/grad_norm', log_scalars)
self.message_hub._log_scalars.clear()
# Test `clip_grad` with `clip_grad_value_`
optim_wrapper = OptimWrapper(
self.optimizer, clip_grad=dict(type='value', clip_value=0.5))
loss = self.model(torch.Tensor(1, 1, 1, 1))
loss.backward()
optim_wrapper._clip_grad()
self.assertNotIn('train/grad_norm', log_scalars)
def test_state_dict(self):
optim_wrapper = OptimWrapper(self.optimizer)
......
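As a quick standalone check of the behavior the new test branch relies on (a sketch, not part of the test suite): `clip_grad_value_` clamps every gradient entry into `[-clip_value, clip_value]` and returns `None`, so the value branch has nothing to log under `train/grad_norm`:
```python
import torch
import torch.nn as nn

model = nn.Linear(1, 1)
model(torch.ones(1, 1)).sum().backward()
ret = torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)
assert ret is None  # value clipping returns nothing to log as a scalar
assert all(p.grad.abs().max() <= 0.5 for p in model.parameters())
```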