Skip to content
Snippets Groups Projects
Commit f2b0540f authored by RangiLyu's avatar RangiLyu Committed by Zaida Zhou
Browse files

[Enhance] Raise warning for abnormal momentum (#655)

parent 4a9df3bd
No related branches found
No related tags found
No related merge requests found
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import warnings
from abc import abstractmethod from abc import abstractmethod
from copy import deepcopy from copy import deepcopy
from typing import Optional from typing import Optional
...@@ -151,6 +152,13 @@ class ExponentialMovingAverage(BaseAveragedModel): ...@@ -151,6 +152,13 @@ class ExponentialMovingAverage(BaseAveragedModel):
Xema_{t+1} = (1 - momentum) * Xema_{t} + momentum * X_t Xema_{t+1} = (1 - momentum) * Xema_{t} + momentum * X_t
.. note::
This :attr:`momentum` argument is different from one used in optimizer
classes and the conventional notion of momentum. Mathematically,
:math:`Xema_{t+1}` is the moving average and :math:`X_t` is the
new observed value. The value of momentum is usually a small number,
allowing observed values to slowly update the ema parameters.
Args: Args:
model (nn.Module): The model to be averaged. model (nn.Module): The model to be averaged.
momentum (float): The momentum used for updating ema parameter. momentum (float): The momentum used for updating ema parameter.
...@@ -175,6 +183,12 @@ class ExponentialMovingAverage(BaseAveragedModel): ...@@ -175,6 +183,12 @@ class ExponentialMovingAverage(BaseAveragedModel):
super().__init__(model, interval, device, update_buffers) super().__init__(model, interval, device, update_buffers)
assert 0.0 < momentum < 1.0, 'momentum must be in range (0.0, 1.0)'\ assert 0.0 < momentum < 1.0, 'momentum must be in range (0.0, 1.0)'\
f'but got {momentum}' f'but got {momentum}'
if momentum > 0.5:
warnings.warn(
'The value of momentum in EMA is usually a small number,'
'which is different from the conventional notion of '
f'momentum but got {momentum}. Please make sure the '
f'value is correct.')
self.momentum = momentum self.momentum = momentum
def avg_func(self, averaged_param: Tensor, source_param: Tensor, def avg_func(self, averaged_param: Tensor, source_param: Tensor,
......
...@@ -93,6 +93,13 @@ class TestAveragedModel(TestCase): ...@@ -93,6 +93,13 @@ class TestAveragedModel(TestCase):
model = torch.nn.Sequential( model = torch.nn.Sequential(
torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)) torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10))
ExponentialMovingAverage(model, momentum=3) ExponentialMovingAverage(model, momentum=3)
with self.assertWarnsRegex(
Warning,
'The value of momentum in EMA is usually a small number'):
model = torch.nn.Sequential(
torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10))
ExponentialMovingAverage(model, momentum=0.9)
# test EMA # test EMA
model = torch.nn.Sequential( model = torch.nn.Sequential(
torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)) torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment