Unverified commit c3aff4fc, authored by Tong Gao and committed by GitHub

[Enhancement] Add PolyParamScheduler, PolyMomentum and PolyLR (#188)

* [Enhancement] Add PolyParamScheduler, PolyMomentum and PolyLR

* min_lr -> eta_min, refined docstr
parent e2a2b043
# Copyright (c) OpenMMLab. All rights reserved.
from .lr_scheduler import (ConstantLR, CosineAnnealingLR, ExponentialLR,
-                           LinearLR, MultiStepLR, StepLR)
+                           LinearLR, MultiStepLR, PolyLR, StepLR)
from .momentum_scheduler import (ConstantMomentum, CosineAnnealingMomentum,
                                 ExponentialMomentum, LinearMomentum,
-                                 MultiStepMomentum, StepMomentum)
+                                 MultiStepMomentum, PolyMomentum, StepMomentum)
from .param_scheduler import (ConstantParamScheduler,
                              CosineAnnealingParamScheduler,
                              ExponentialParamScheduler, LinearParamScheduler,
-                              MultiStepParamScheduler, StepParamScheduler,
-                              _ParamScheduler)
+                              MultiStepParamScheduler, PolyParamScheduler,
+                              StepParamScheduler, _ParamScheduler)

__all__ = [
    'ConstantLR', 'CosineAnnealingLR', 'ExponentialLR', 'LinearLR',
@@ -16,5 +16,6 @@ __all__ = [
    'ExponentialMomentum', 'LinearMomentum', 'MultiStepMomentum',
    'StepMomentum', 'ConstantParamScheduler', 'CosineAnnealingParamScheduler',
    'ExponentialParamScheduler', 'LinearParamScheduler',
-    'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler'
+    'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler',
+    'PolyParamScheduler', 'PolyLR', 'PolyMomentum'
]
@@ -7,7 +7,8 @@ from mmengine.registry import PARAM_SCHEDULERS
from .param_scheduler import (INF, ConstantParamScheduler,
                              CosineAnnealingParamScheduler,
                              ExponentialParamScheduler, LinearParamScheduler,
-                              MultiStepParamScheduler, StepParamScheduler)
+                              MultiStepParamScheduler, PolyParamScheduler,
+                              StepParamScheduler)


@PARAM_SCHEDULERS.register_module()
@@ -294,3 +295,49 @@ class StepLR(StepParamScheduler):
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)


@PARAM_SCHEDULERS.register_module()
class PolyLR(PolyParamScheduler):
    """Decays the learning rate of each parameter group in a polynomial decay
    scheme.

    Notice that such decay can happen simultaneously with other changes to the
    parameter value from outside this scheduler.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        eta_min (float): Minimum learning rate at the end of scheduling.
            Defaults to 0.
        power (float): The power of the polynomial. Defaults to 1.0.
        begin (int): Step at which to start updating the parameters.
            Defaults to 0.
        end (int): Step at which to stop updating the parameters.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 eta_min: float = 0,
                 power: float = 1,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):
        super().__init__(
            optimizer,
            param_name='lr',
            eta_min=eta_min,
            power=power,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)
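
For context, here is a minimal usage sketch of the new PolyLR; it is not part of the commit. It assumes a toy `torch.nn.Linear` model and a plain SGD optimizer, and the numbers mirror the ones exercised by the tests further down.

```python
import torch

from mmengine.optim.scheduler import PolyLR

# Toy setup (assumption): any model/optimizer pair works, SGD with lr=0.05 here.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9)

# Decay lr from 0.05 towards eta_min over end - begin - 1 = 4 steps,
# then hold it at eta_min.
scheduler = PolyLR(optimizer, eta_min=0.001, power=0.9, end=5)

for epoch in range(8):
    # ... a real training epoch with optimizer.step() calls would go here ...
    optimizer.step()  # placeholder for a real parameter update
    scheduler.step()
    print(epoch, [group['lr'] for group in optimizer.param_groups])
```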
@@ -7,7 +7,8 @@ from mmengine.registry import PARAM_SCHEDULERS
from .param_scheduler import (INF, ConstantParamScheduler,
                              CosineAnnealingParamScheduler,
                              ExponentialParamScheduler, LinearParamScheduler,
-                              MultiStepParamScheduler, StepParamScheduler)
+                              MultiStepParamScheduler, PolyParamScheduler,
+                              StepParamScheduler)


@PARAM_SCHEDULERS.register_module()
@@ -294,3 +295,49 @@ class StepMomentum(StepParamScheduler):
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)


@PARAM_SCHEDULERS.register_module()
class PolyMomentum(PolyParamScheduler):
    """Decays the momentum of each parameter group in a polynomial decay
    scheme.

    Notice that such decay can happen simultaneously with other changes to the
    parameter value from outside this scheduler.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        eta_min (float): Minimum momentum at the end of scheduling.
            Defaults to 0.
        power (float): The power of the polynomial. Defaults to 1.0.
        begin (int): Step at which to start updating the parameters.
            Defaults to 0.
        end (int): Step at which to stop updating the parameters.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 eta_min: float = 0,
                 power: float = 1,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):
        super().__init__(
            optimizer,
            param_name='momentum',
            eta_min=eta_min,
            power=power,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)
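
A note on PolyMomentum: because it is PolyParamScheduler with `param_name='momentum'`, it reads and writes the `momentum` entry of each param group, so it is only meaningful for optimizers that expose momentum under that key (for example `torch.optim.SGD` with `momentum` set). A minimal sketch under that assumption, not part of the commit:

```python
import torch

from mmengine.optim.scheduler import PolyMomentum

model = torch.nn.Linear(4, 2)
# SGD exposes a 'momentum' key in its param groups, which the scheduler adjusts.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Anneal momentum from 0.9 towards eta_min=0.5 over 4 steps, then hold it there.
scheduler = PolyMomentum(optimizer, eta_min=0.5, power=1.0, end=5)

for epoch in range(6):
    scheduler.step()
    print(epoch, [group['momentum'] for group in optimizer.param_groups])
```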
@@ -534,6 +534,7 @@ class LinearParamScheduler(_ParamScheduler):
    Notice that such decay can happen simultaneously with other changes to the
    parameter value from outside this scheduler.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        start_factor (float): The number we multiply parameter value in the
@@ -598,3 +599,64 @@ class LinearParamScheduler(_ParamScheduler):
             (self.end_factor - self.start_factor)))
            for group in self.optimizer.param_groups
        ]


@PARAM_SCHEDULERS.register_module()
class PolyParamScheduler(_ParamScheduler):
    """Decays the parameter value of each parameter group in a polynomial decay
    scheme.

    Notice that such decay can happen simultaneously with other changes to the
    parameter value from outside this scheduler.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        param_name (str): Name of the parameter to be adjusted, such as
            ``lr`` or ``momentum``.
        eta_min (float): Minimum parameter value at the end of scheduling.
            Defaults to 0.
        power (float): The power of the polynomial. Defaults to 1.0.
        begin (int): Step at which to start updating the parameters.
            Defaults to 0.
        end (int): Step at which to stop updating the parameters.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.
    """

    def __init__(self,
                 optimizer: Optimizer,
                 param_name: str,
                 eta_min: float = 0,
                 power: float = 1.0,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):
        self.eta_min = eta_min
        self.power = power
        self.total_iters = end - begin - 1
        super().__init__(
            optimizer,
            param_name=param_name,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)

    def _get_value(self):
        # Keep the initial parameter value untouched on the first step.
        if self.last_step == 0:
            return [
                group[self.param_name] for group in self.optimizer.param_groups
            ]
        # Shrink the remaining gap to eta_min by a polynomial factor per step.
        return [(group[self.param_name] - self.eta_min) *
                (1 - 1 / (self.total_iters - self.last_step + 1))**self.power +
                self.eta_min for group in self.optimizer.param_groups]
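
Worth noting: the per-step factor in `_get_value` telescopes, so the recursion is equivalent to the closed-form polynomial decay `value_t = eta_min + (value_0 - eta_min) * (1 - t / total_iters) ** power`, which is exactly the formula the tests below use to build their expected targets. A quick standalone check, plain Python, with the same numbers as those tests (min_lr=0.001, initial lr 0.05, power=0.9, total_iters=4):

```python
# Compare the recursive update used by PolyParamScheduler._get_value with the
# closed-form expression used in the tests.
eta_min, v0, power, total_iters = 0.001, 0.05, 0.9, 4

recursive = [v0]
for t in range(1, total_iters + 1):
    prev = recursive[-1]
    # Same multiplicative step as _get_value at last_step == t.
    recursive.append((prev - eta_min) *
                     (1 - 1 / (total_iters - t + 1))**power + eta_min)

closed_form = [
    eta_min + (v0 - eta_min) * (1 - t / total_iters)**power
    for t in range(total_iters + 1)
]

assert all(abs(a - b) < 1e-9 for a, b in zip(recursive, closed_form))
```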
@@ -8,7 +8,7 @@ import torch.optim as optim
from mmengine.optim.scheduler import (ConstantLR, CosineAnnealingLR,
                                      ExponentialLR, LinearLR, MultiStepLR,
-                                      StepLR, _ParamScheduler)
+                                      PolyLR, StepLR, _ParamScheduler)
from mmengine.testing import assert_allclose
@@ -283,6 +283,21 @@ class TestLRScheduler(TestCase):
        scheduler = CosineAnnealingLR(self.optimizer, T_max=t, eta_min=eta_min)
        self._test_scheduler_value(scheduler, targets, epochs)
def test_poly_scheduler(self):
epochs = 10
power = 0.9
min_lr = 0.001
iters = 4
single_targets = [
min_lr + (0.05 - min_lr) * (1 - i / iters)**power
for i in range(iters)
] + [min_lr] * (
epochs - iters)
targets = [single_targets, [x * epochs for x in single_targets]]
scheduler = PolyLR(
self.optimizer, power=power, eta_min=min_lr, end=iters + 1)
self._test_scheduler_value(scheduler, targets, epochs=10)
    def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
        scheduler = construct()
        for _ in range(epochs):
@@ -331,6 +346,12 @@ class TestLRScheduler(TestCase):
            lambda: LinearLR(self.optimizer, start_factor=0, end_factor=0.3),
            epochs=epochs)
def test_poly_scheduler_state_dict(self):
self._check_scheduler_state_dict(
lambda: PolyLR(self.optimizer, power=0.5, eta_min=0.001),
lambda: PolyLR(self.optimizer, power=0.8, eta_min=0.002),
epochs=10)
    def test_multi_scheduler_without_overlap_linear_multi_step(self):
        # use Linear in the first 5 epochs and then use MultiStep
        epochs = 12
...
@@ -9,8 +9,8 @@ import torch.optim as optim
from mmengine.optim.scheduler import (ConstantMomentum,
                                      CosineAnnealingMomentum,
                                      ExponentialMomentum, LinearMomentum,
-                                      MultiStepMomentum, StepMomentum,
-                                      _ParamScheduler)
+                                      MultiStepMomentum, PolyMomentum,
+                                      StepMomentum, _ParamScheduler)
from mmengine.testing import assert_allclose
@@ -284,6 +284,21 @@ class TestMomentumScheduler(TestCase):
            self.optimizer, T_max=t, eta_min=eta_min)
        self._test_scheduler_value(scheduler, targets, epochs)
def test_poly_scheduler(self):
epochs = 10
power = 0.9
min_lr = 0.001
iters = 4
single_targets = [
min_lr + (0.05 - min_lr) * (1 - i / iters)**power
for i in range(iters)
] + [min_lr] * (
epochs - iters)
targets = [single_targets, [x * epochs for x in single_targets]]
scheduler = PolyMomentum(
self.optimizer, power=power, eta_min=min_lr, end=iters + 1)
self._test_scheduler_value(scheduler, targets, epochs=10)
    def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
        scheduler = construct()
        for _ in range(epochs):
@@ -333,6 +348,12 @@ class TestMomentumScheduler(TestCase):
                self.optimizer, start_factor=0, end_factor=0.3),
            epochs=epochs)
def test_poly_scheduler_state_dict(self):
self._check_scheduler_state_dict(
lambda: PolyMomentum(self.optimizer, power=0.5, eta_min=0.001),
lambda: PolyMomentum(self.optimizer, power=0.8, eta_min=0.002),
epochs=10)
    def test_multi_scheduler_without_overlap_linear_multi_step(self):
        # use Linear in the first 5 epochs and then use MultiStep
        epochs = 12
...
@@ -6,12 +6,15 @@ import torch
import torch.nn.functional as F
import torch.optim as optim

+# yapf: disable
from mmengine.optim.scheduler import (ConstantParamScheduler,
                                      CosineAnnealingParamScheduler,
                                      ExponentialParamScheduler,
                                      LinearParamScheduler,
                                      MultiStepParamScheduler,
-                                      StepParamScheduler, _ParamScheduler)
+                                      PolyParamScheduler, StepParamScheduler,
+                                      _ParamScheduler)
+# yapf: enable
from mmengine.testing import assert_allclose
@@ -336,6 +339,25 @@ class TestParameterScheduler(TestCase):
            self.optimizer, param_name='lr', T_max=t, eta_min=eta_min)
        self._test_scheduler_value(scheduler, targets, epochs)
def test_poly_scheduler(self):
epochs = 10
power = 0.9
min_lr = 0.001
iters = 4
single_targets = [
min_lr + (0.05 - min_lr) * (1 - i / iters)**power
for i in range(iters)
] + [min_lr] * (
epochs - iters)
targets = [single_targets, [x * epochs for x in single_targets]]
scheduler = PolyParamScheduler(
self.optimizer,
param_name='lr',
power=power,
eta_min=min_lr,
end=iters + 1)
self._test_scheduler_value(scheduler, targets, epochs=10)
    def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
        scheduler = construct()
        for _ in range(epochs):
@@ -402,6 +424,14 @@ class TestParameterScheduler(TestCase):
                end_factor=0.3),
            epochs=epochs)
def test_poly_scheduler_state_dict(self):
self._check_scheduler_state_dict(
lambda: PolyParamScheduler(
self.optimizer, param_name='lr', power=0.5, eta_min=0.001),
lambda: PolyParamScheduler(
self.optimizer, param_name='lr', power=0.8, eta_min=0.002),
epochs=10)
    def test_multi_scheduler_without_overlap_linear_multi_step(self):
        # use Linear in the first 5 epochs and then use MultiStep
        epochs = 12
...