Unverified commit c3aff4fc, authored by Tong Gao and committed by GitHub

[Enhancement] Add PolyParamScheduler, PolyMomentum and PolyLR (#188)

* [Enhancement] Add PolyParamScheduler, PolyMomentum and PolyLR

* min_lr -> eta_min, refined docstr
parent e2a2b043
# Copyright (c) OpenMMLab. All rights reserved.
from .lr_scheduler import (ConstantLR, CosineAnnealingLR, ExponentialLR,
-                          LinearLR, MultiStepLR, StepLR)
+                          LinearLR, MultiStepLR, PolyLR, StepLR)
from .momentum_scheduler import (ConstantMomentum, CosineAnnealingMomentum,
                                 ExponentialMomentum, LinearMomentum,
-                                MultiStepMomentum, StepMomentum)
+                                MultiStepMomentum, PolyMomentum, StepMomentum)
from .param_scheduler import (ConstantParamScheduler,
                              CosineAnnealingParamScheduler,
                              ExponentialParamScheduler, LinearParamScheduler,
-                             MultiStepParamScheduler, StepParamScheduler,
-                             _ParamScheduler)
+                             MultiStepParamScheduler, PolyParamScheduler,
+                             StepParamScheduler, _ParamScheduler)
__all__ = [
'ConstantLR', 'CosineAnnealingLR', 'ExponentialLR', 'LinearLR',
@@ -16,5 +16,6 @@ __all__ = [
'ExponentialMomentum', 'LinearMomentum', 'MultiStepMomentum',
'StepMomentum', 'ConstantParamScheduler', 'CosineAnnealingParamScheduler',
'ExponentialParamScheduler', 'LinearParamScheduler',
-    'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler'
+    'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler',
+    'PolyParamScheduler', 'PolyLR', 'PolyMomentum'
]
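As a quick sanity check of the updated exports, a minimal sketch (not part of this commit, and assuming an environment with this change installed) showing that the new classes are importable from the same public namespace the unit tests below use:

# Minimal sketch (not part of this commit): the new schedulers are exported
# from the public mmengine.optim.scheduler namespace.
from mmengine.optim.scheduler import PolyLR, PolyMomentum, PolyParamScheduler

print(PolyLR.__mro__[1])  # PolyLR subclasses PolyParamScheduler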
@@ -7,7 +7,8 @@ from mmengine.registry import PARAM_SCHEDULERS
from .param_scheduler import (INF, ConstantParamScheduler,
                              CosineAnnealingParamScheduler,
                              ExponentialParamScheduler, LinearParamScheduler,
-                             MultiStepParamScheduler, StepParamScheduler)
+                             MultiStepParamScheduler, PolyParamScheduler,
+                             StepParamScheduler)
@PARAM_SCHEDULERS.register_module()
@@ -294,3 +295,49 @@ class StepLR(StepParamScheduler):
last_step=last_step,
by_epoch=by_epoch,
verbose=verbose)
@PARAM_SCHEDULERS.register_module()
class PolyLR(PolyParamScheduler):
"""Decays the learning rate of each parameter group in a polynomial decay
scheme.
Notice that such decay can happen simultaneously with other changes to the
parameter value from outside this scheduler.
Args:
optimizer (Optimizer): Wrapped optimizer.
eta_min (float): Minimum learning rate at the end of scheduling.
Defaults to 0.
power (float): The power of the polynomial. Defaults to 1.0.
begin (int): Step at which to start updating the parameters.
Defaults to 0.
end (int): Step at which to stop updating the parameters.
Defaults to INF.
last_step (int): The index of last step. Used for resume without
state dict. Defaults to -1.
by_epoch (bool): Whether the scheduled parameters are updated by
epochs. Defaults to True.
verbose (bool): Whether to print the value for each update.
Defaults to False.
"""
def __init__(self,
optimizer: torch.optim.Optimizer,
eta_min: float = 0,
power: float = 1,
begin: int = 0,
end: int = INF,
last_step: int = -1,
by_epoch: bool = True,
verbose: bool = False):
super().__init__(
optimizer,
param_name='lr',
eta_min=eta_min,
power=power,
begin=begin,
end=end,
last_step=last_step,
by_epoch=by_epoch,
verbose=verbose)
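For context, a hedged usage sketch (the toy parameter group and the concrete values are assumptions, not part of this diff) showing PolyLR driving the learning rate of a plain torch.optim.SGD optimizer:

# Hedged usage sketch (toy setup assumed, not from this commit): PolyLR decays
# the lr from 0.05 towards eta_min over end - begin - 1 scheduler steps.
import torch
from mmengine.optim.scheduler import PolyLR

params = [torch.nn.Parameter(torch.zeros(1))]  # stand-in for model parameters
optimizer = torch.optim.SGD(params, lr=0.05, momentum=0.9)
scheduler = PolyLR(optimizer, eta_min=0.001, power=0.9, end=5, by_epoch=False)

for step in range(5):
    optimizer.step()   # normally preceded by a forward/backward pass
    scheduler.step()   # polynomially updates optimizer.param_groups[...]['lr']
    print(step, optimizer.param_groups[0]['lr'])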
@@ -7,7 +7,8 @@ from mmengine.registry import PARAM_SCHEDULERS
from .param_scheduler import (INF, ConstantParamScheduler,
                              CosineAnnealingParamScheduler,
                              ExponentialParamScheduler, LinearParamScheduler,
-                             MultiStepParamScheduler, StepParamScheduler)
+                             MultiStepParamScheduler, PolyParamScheduler,
+                             StepParamScheduler)
@PARAM_SCHEDULERS.register_module()
@@ -294,3 +295,49 @@ class StepMomentum(StepParamScheduler):
last_step=last_step,
by_epoch=by_epoch,
verbose=verbose)
@PARAM_SCHEDULERS.register_module()
class PolyMomentum(PolyParamScheduler):
"""Decays the momentum of each parameter group in a polynomial decay
scheme.
Notice that such decay can happen simultaneously with other changes to the
parameter value from outside this scheduler.
Args:
optimizer (Optimizer): Wrapped optimizer.
eta_min (float): Minimum momentum at the end of scheduling.
Defaults to 0.
power (float): The power of the polynomial. Defaults to 1.0.
begin (int): Step at which to start updating the parameters.
Defaults to 0.
end (int): Step at which to stop updating the parameters.
Defaults to INF.
last_step (int): The index of last step. Used for resume without
state dict. Defaults to -1.
by_epoch (bool): Whether the scheduled parameters are updated by
epochs. Defaults to True.
verbose (bool): Whether to print the value for each update.
Defaults to False.
"""
def __init__(self,
optimizer: torch.optim.Optimizer,
eta_min: float = 0,
power: float = 1,
begin: int = 0,
end: int = INF,
last_step: int = -1,
by_epoch: bool = True,
verbose: bool = False):
super().__init__(
optimizer,
param_name='momentum',
eta_min=eta_min,
power=power,
begin=begin,
end=end,
last_step=last_step,
by_epoch=by_epoch,
verbose=verbose)
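Because both classes carry the @PARAM_SCHEDULERS.register_module() decorator, they can also be referenced by type name from a runner config. A hedged, config-style sketch (all concrete values are assumptions, not taken from this commit):

# Hedged config-style sketch (values are assumptions, not from this commit):
# the registry decorators above make the classes addressable by type name
# in an MMEngine runner config.
param_scheduler = [
    dict(type='PolyLR', eta_min=1e-4, power=0.9, begin=0, end=100,
         by_epoch=True),
    dict(type='PolyMomentum', eta_min=0.85, power=0.9, begin=0, end=100,
         by_epoch=True),
]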
@@ -534,6 +534,7 @@ class LinearParamScheduler(_ParamScheduler):
Notice that such decay can happen simultaneously with other changes to the
parameter value from outside this scheduler.
Args:
optimizer (Optimizer): Wrapped optimizer.
start_factor (float): The number we multiply parameter value in the
@@ -598,3 +599,64 @@ class LinearParamScheduler(_ParamScheduler):
(self.end_factor - self.start_factor)))
for group in self.optimizer.param_groups
]
@PARAM_SCHEDULERS.register_module()
class PolyParamScheduler(_ParamScheduler):
"""Decays the parameter value of each parameter group in a polynomial decay
scheme.
Notice that such decay can happen simultaneously with other changes to the
parameter value from outside this scheduler.
Args:
optimizer (Optimizer): Wrapped optimizer.
eta_min (float): Minimum parameter value at the end of scheduling.
Defaults to 0.
power (float): The power of the polynomial. Defaults to 1.0.
begin (int): Step at which to start updating the parameters.
Defaults to 0.
end (int): Step at which to stop updating the parameters.
Defaults to INF.
last_step (int): The index of last step. Used for resume without
state dict. Defaults to -1.
by_epoch (bool): Whether the scheduled parameters are updated by
epochs. Defaults to True.
verbose (bool): Whether to print the value for each update.
Defaults to False.
"""
def __init__(self,
optimizer: Optimizer,
param_name: str,
eta_min: float = 0,
power: float = 1.0,
begin: int = 0,
end: int = INF,
last_step: int = -1,
by_epoch: bool = True,
verbose: bool = False):
self.eta_min = eta_min
self.power = power
self.total_iters = end - begin - 1
super().__init__(
optimizer,
param_name=param_name,
begin=begin,
end=end,
last_step=last_step,
by_epoch=by_epoch,
verbose=verbose)
def _get_value(self):
if self.last_step == 0:
return [
group[self.param_name] for group in self.optimizer.param_groups
]
return [(group[self.param_name] - self.eta_min) *
(1 - 1 / (self.total_iters - self.last_step + 1))**self.power +
self.eta_min for group in self.optimizer.param_groups]
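The multiplicative update in _get_value telescopes to the closed form eta_min + (base - eta_min) * (1 - t / total_iters) ** power, which is the same expression the tests below use to build their expected targets. A small standalone sketch (not part of the commit) that checks this equivalence:

# Standalone sketch (not from this commit): the recursive update used in
# _get_value telescopes to the closed-form polynomial decay.
base, eta_min, power, total_iters = 0.05, 0.001, 0.9, 4

value = base
for t in range(1, total_iters + 1):
    # recursive step, mirroring _get_value at last_step == t
    value = (value - eta_min) * (1 - 1 / (total_iters - t + 1))**power + eta_min
    # closed form: eta_min + (base - eta_min) * (1 - t / total_iters)**power
    closed = eta_min + (base - eta_min) * (1 - t / total_iters)**power
    assert abs(value - closed) < 1e-12, (t, value, closed)
print('recursive update matches the closed form')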
@@ -8,7 +8,7 @@ import torch.optim as optim
from mmengine.optim.scheduler import (ConstantLR, CosineAnnealingLR,
                                      ExponentialLR, LinearLR, MultiStepLR,
-                                     StepLR, _ParamScheduler)
+                                     PolyLR, StepLR, _ParamScheduler)
from mmengine.testing import assert_allclose
@@ -283,6 +283,21 @@ class TestLRScheduler(TestCase):
scheduler = CosineAnnealingLR(self.optimizer, T_max=t, eta_min=eta_min)
self._test_scheduler_value(scheduler, targets, epochs)
def test_poly_scheduler(self):
epochs = 10
power = 0.9
min_lr = 0.001
iters = 4
single_targets = [
min_lr + (0.05 - min_lr) * (1 - i / iters)**power
for i in range(iters)
] + [min_lr] * (
epochs - iters)
targets = [single_targets, [x * epochs for x in single_targets]]
scheduler = PolyLR(
self.optimizer, power=power, eta_min=min_lr, end=iters + 1)
self._test_scheduler_value(scheduler, targets, epochs=10)
def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
scheduler = construct()
for _ in range(epochs):
@@ -331,6 +346,12 @@ class TestLRScheduler(TestCase):
lambda: LinearLR(self.optimizer, start_factor=0, end_factor=0.3),
epochs=epochs)
def test_poly_scheduler_state_dict(self):
self._check_scheduler_state_dict(
lambda: PolyLR(self.optimizer, power=0.5, eta_min=0.001),
lambda: PolyLR(self.optimizer, power=0.8, eta_min=0.002),
epochs=10)
def test_multi_scheduler_without_overlap_linear_multi_step(self):
# use Linear in the first 5 epochs and then use MultiStep
epochs = 12
@@ -9,8 +9,8 @@ import torch.optim as optim
from mmengine.optim.scheduler import (ConstantMomentum,
                                      CosineAnnealingMomentum,
                                      ExponentialMomentum, LinearMomentum,
-                                     MultiStepMomentum, StepMomentum,
-                                     _ParamScheduler)
+                                     MultiStepMomentum, PolyMomentum,
+                                     StepMomentum, _ParamScheduler)
from mmengine.testing import assert_allclose
@@ -284,6 +284,21 @@ class TestMomentumScheduler(TestCase):
self.optimizer, T_max=t, eta_min=eta_min)
self._test_scheduler_value(scheduler, targets, epochs)
def test_poly_scheduler(self):
epochs = 10
power = 0.9
min_lr = 0.001
iters = 4
single_targets = [
min_lr + (0.05 - min_lr) * (1 - i / iters)**power
for i in range(iters)
] + [min_lr] * (
epochs - iters)
targets = [single_targets, [x * epochs for x in single_targets]]
scheduler = PolyMomentum(
self.optimizer, power=power, eta_min=min_lr, end=iters + 1)
self._test_scheduler_value(scheduler, targets, epochs=10)
def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
scheduler = construct()
for _ in range(epochs):
@@ -333,6 +348,12 @@ class TestMomentumScheduler(TestCase):
self.optimizer, start_factor=0, end_factor=0.3),
epochs=epochs)
def test_poly_scheduler_state_dict(self):
self._check_scheduler_state_dict(
lambda: PolyMomentum(self.optimizer, power=0.5, eta_min=0.001),
lambda: PolyMomentum(self.optimizer, power=0.8, eta_min=0.002),
epochs=10)
def test_multi_scheduler_without_overlap_linear_multi_step(self):
# use Linear in the first 5 epochs and then use MultiStep
epochs = 12
@@ -6,12 +6,15 @@ import torch
import torch.nn.functional as F
import torch.optim as optim
# yapf: disable
from mmengine.optim.scheduler import (ConstantParamScheduler,
                                      CosineAnnealingParamScheduler,
                                      ExponentialParamScheduler,
                                      LinearParamScheduler,
                                      MultiStepParamScheduler,
-                                     StepParamScheduler, _ParamScheduler)
+                                     PolyParamScheduler, StepParamScheduler,
+                                     _ParamScheduler)
# yapf: enable
from mmengine.testing import assert_allclose
@@ -336,6 +339,25 @@ class TestParameterScheduler(TestCase):
self.optimizer, param_name='lr', T_max=t, eta_min=eta_min)
self._test_scheduler_value(scheduler, targets, epochs)
def test_poly_scheduler(self):
epochs = 10
power = 0.9
min_lr = 0.001
iters = 4
single_targets = [
min_lr + (0.05 - min_lr) * (1 - i / iters)**power
for i in range(iters)
] + [min_lr] * (
epochs - iters)
targets = [single_targets, [x * epochs for x in single_targets]]
scheduler = PolyParamScheduler(
self.optimizer,
param_name='lr',
power=power,
eta_min=min_lr,
end=iters + 1)
self._test_scheduler_value(scheduler, targets, epochs=10)
def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
scheduler = construct()
for _ in range(epochs):
@@ -402,6 +424,14 @@ class TestParameterScheduler(TestCase):
end_factor=0.3),
epochs=epochs)
def test_poly_scheduler_state_dict(self):
self._check_scheduler_state_dict(
lambda: PolyParamScheduler(
self.optimizer, param_name='lr', power=0.5, eta_min=0.001),
lambda: PolyParamScheduler(
self.optimizer, param_name='lr', power=0.8, eta_min=0.002),
epochs=10)
def test_multi_scheduler_without_overlap_linear_multi_step(self):
# use Linear in the first 5 epochs and then use MultiStep
epochs = 12