diff --git a/mmengine/optim/scheduler/__init__.py b/mmengine/optim/scheduler/__init__.py
index df54f738ae3b5134b10d9809fc64dcfa4c9bb33e..c12d1c2970d39561e3ddec2e92e49f67bbfdc4ba 100644
--- a/mmengine/optim/scheduler/__init__.py
+++ b/mmengine/optim/scheduler/__init__.py
@@ -1,16 +1,22 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .lr_scheduler import (ConstantLR, CosineAnnealingLR, ExponentialLR,
-                           LinearLR, MultiStepLR, OneCycleLR, PolyLR, StepLR)
+# yapf: disable
+from .lr_scheduler import (ConstantLR, CosineAnnealingLR, CosineRestartLR,
+                           ExponentialLR, LinearLR, MultiStepLR, OneCycleLR,
+                           PolyLR, StepLR)
 from .momentum_scheduler import (ConstantMomentum, CosineAnnealingMomentum,
-                                 ExponentialMomentum, LinearMomentum,
-                                 MultiStepMomentum, PolyMomentum, StepMomentum)
+                                 CosineRestartMomentum, ExponentialMomentum,
+                                 LinearMomentum, MultiStepMomentum,
+                                 PolyMomentum, StepMomentum)
 from .param_scheduler import (ConstantParamScheduler,
                               CosineAnnealingParamScheduler,
+                              CosineRestartParamScheduler,
                               ExponentialParamScheduler, LinearParamScheduler,
                               MultiStepParamScheduler, OneCycleParamScheduler,
                               PolyParamScheduler, StepParamScheduler,
                               _ParamScheduler)
+# yapf: enable
+
 
 __all__ = [
     'ConstantLR', 'CosineAnnealingLR', 'ExponentialLR', 'LinearLR',
     'MultiStepLR', 'StepLR', 'ConstantMomentum', 'CosineAnnealingMomentum',
@@ -19,5 +25,6 @@ __all__ = [
     'ExponentialParamScheduler', 'LinearParamScheduler',
     'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler',
     'PolyParamScheduler', 'PolyLR', 'PolyMomentum', 'OneCycleParamScheduler',
-    'OneCycleLR'
+    'OneCycleLR', 'CosineRestartParamScheduler', 'CosineRestartLR',
+    'CosineRestartMomentum'
 ]
diff --git a/mmengine/optim/scheduler/lr_scheduler.py b/mmengine/optim/scheduler/lr_scheduler.py
index 21ae823e04a4a90eb5aa22c4d385b04b92ab81ed..cbb82dfc550ae933e5320a7db710ffca7fb72b3c 100644
--- a/mmengine/optim/scheduler/lr_scheduler.py
+++ b/mmengine/optim/scheduler/lr_scheduler.py
@@ -2,6 +2,7 @@
 from mmengine.registry import PARAM_SCHEDULERS
 from .param_scheduler import (ConstantParamScheduler,
                               CosineAnnealingParamScheduler,
+                              CosineRestartParamScheduler,
                               ExponentialParamScheduler, LinearParamScheduler,
                               MultiStepParamScheduler, OneCycleParamScheduler,
                               PolyParamScheduler, StepParamScheduler)
@@ -277,3 +278,35 @@ class OneCycleLR(LRSchedulerMixin, OneCycleParamScheduler):
     .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large
         Learning Rates: https://arxiv.org/abs/1708.07120
     """# noqa E501
+
+
+@PARAM_SCHEDULERS.register_module()
+class CosineRestartLR(LRSchedulerMixin, CosineRestartParamScheduler):
+    """Sets the learning rate of each parameter group according to the cosine
+    annealing with restarts scheme. The cosine restart policy anneals the
+    learning rate from the initial value to `eta_min` with a cosine annealing
+    schedule and then restarts another period from the maximum value
+    multiplied by `restart_weight`.
+
+    Args:
+        optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
+            optimizer.
+        periods (list[int]): Periods for each cosine annealing cycle.
+        restart_weights (list[float]): Restart weights at each
+            restart iteration. Defaults to [1].
+        eta_min (float): Minimum parameter value at the end of scheduling.
+            Defaults to None.
+        eta_min_ratio (float, optional): The ratio of minimum parameter value
+            to the base parameter value. Either `eta_min` or `eta_min_ratio`
+            should be specified. Defaults to None.
+        begin (int): Step at which to start updating the parameters.
+            Defaults to 0.
+        end (int): Step at which to stop updating the parameters.
+            Defaults to INF.
+        last_step (int): The index of last step. Used for resume without
+            state dict. Defaults to -1.
+        by_epoch (bool): Whether the scheduled parameters are updated by
+            epochs. Defaults to True.
+        verbose (bool): Whether to print the value for each update.
+            Defaults to False.
+    """
diff --git a/mmengine/optim/scheduler/momentum_scheduler.py b/mmengine/optim/scheduler/momentum_scheduler.py
index 2cbe9d707914d7b53861c7bebbfdb2c1d62e74dc..9cb5f9e914cecd29ac595585c3672ddaedc25a6a 100644
--- a/mmengine/optim/scheduler/momentum_scheduler.py
+++ b/mmengine/optim/scheduler/momentum_scheduler.py
@@ -2,6 +2,7 @@
 from mmengine.registry import PARAM_SCHEDULERS
 from .param_scheduler import (ConstantParamScheduler,
                               CosineAnnealingParamScheduler,
+                              CosineRestartParamScheduler,
                               ExponentialParamScheduler, LinearParamScheduler,
                               MultiStepParamScheduler, PolyParamScheduler,
                               StepParamScheduler)
@@ -243,3 +244,36 @@ class PolyMomentum(MomentumSchedulerMixin, PolyParamScheduler):
         verbose (bool): Whether to print the value for each update.
             Defaults to False.
     """
+
+
+@PARAM_SCHEDULERS.register_module()
+class CosineRestartMomentum(MomentumSchedulerMixin,
+                            CosineRestartParamScheduler):
+    """Sets the momentum of each parameter group according to the cosine
+    annealing with restarts scheme. The cosine restart policy anneals the
+    momentum from the initial value to `eta_min` with a cosine annealing
+    schedule and then restarts another period from the maximum value
+    multiplied by `restart_weight`.
+
+    Args:
+        optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
+            optimizer.
+        periods (list[int]): Periods for each cosine annealing cycle.
+        restart_weights (list[float]): Restart weights at each
+            restart iteration. Defaults to [1].
+        eta_min (float): Minimum parameter value at the end of scheduling.
+            Defaults to None.
+        eta_min_ratio (float, optional): The ratio of minimum parameter value
+            to the base parameter value. Either `eta_min` or `eta_min_ratio`
+            should be specified. Defaults to None.
+        begin (int): Step at which to start updating the parameters.
+            Defaults to 0.
+        end (int): Step at which to stop updating the parameters.
+            Defaults to INF.
+        last_step (int): The index of last step. Used for resume without
+            state dict. Defaults to -1.
+        by_epoch (bool): Whether the scheduled parameters are updated by
+            epochs. Defaults to True.
+        verbose (bool): Whether to print the value for each update.
+            Defaults to False.
+    """
diff --git a/mmengine/optim/scheduler/param_scheduler.py b/mmengine/optim/scheduler/param_scheduler.py
index 50ea25a12e058c0ffa8051bee326d8b2280912a3..519e7b35b66a5cbc771482d70f0cd09252f7c223 100644
--- a/mmengine/optim/scheduler/param_scheduler.py
+++ b/mmengine/optim/scheduler/param_scheduler.py
@@ -9,7 +9,7 @@
 import warnings
 import weakref
 from collections import Counter
 from functools import wraps
-from typing import Callable, List, Optional, Union
+from typing import Callable, List, Optional, Sequence, Union
 
 from torch.optim import Optimizer
@@ -227,6 +227,8 @@ class StepParamScheduler(_ParamScheduler):
 
     Args:
         optimizer (OptimWrapper or Optimizer): Wrapped optimizer.
+        param_name (str): Name of the parameter to be adjusted, such as
+            ``lr``, ``momentum``.
         step_size (int): Period of parameter value decay.
         gamma (float): Multiplicative factor of parameter value decay.
             Defaults to 0.1.
@@ -313,6 +315,8 @@ class MultiStepParamScheduler(_ParamScheduler):
 
     Args:
         optimizer (OptimWrapper or Optimizer): Wrapped optimizer.
+        param_name (str): Name of the parameter to be adjusted, such as
+            ``lr``, ``momentum``.
         milestones (list): List of epoch indices. Must be increasing.
         gamma (float): Multiplicative factor of parameter value decay.
             Defaults to 0.1.
@@ -401,6 +405,8 @@ class ConstantParamScheduler(_ParamScheduler):
     Args:
         optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
             optimizer.
+        param_name (str): Name of the parameter to be adjusted, such as
+            ``lr``, ``momentum``.
         factor (float): The number we multiply parameter value until the
             milestone. Defaults to 1./3.
         begin (int): Step at which to start updating the parameters.
@@ -488,6 +494,8 @@ class ExponentialParamScheduler(_ParamScheduler):
     Args:
         optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
             optimizer.
+        param_name (str): Name of the parameter to be adjusted, such as
+            ``lr``, ``momentum``.
         gamma (float): Multiplicative factor of parameter value decay.
         begin (int): Step at which to start updating the parameters.
             Defaults to 0.
@@ -585,6 +593,8 @@ class CosineAnnealingParamScheduler(_ParamScheduler):
     Args:
         optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
             optimizer.
+        param_name (str): Name of the parameter to be adjusted, such as
+            ``lr``, ``momentum``.
         T_max (int, optional): Maximum number of iterations. If not specified,
             use ``end - begin``. Defaults to None.
         eta_min (float): Minimum parameter value. Defaults to 0.
@@ -684,6 +694,8 @@ class LinearParamScheduler(_ParamScheduler):
     Args:
         optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
             optimizer.
+        param_name (str): Name of the parameter to be adjusted, such as
+            ``lr``, ``momentum``.
         start_factor (float): The number we multiply parameter value in the
             first epoch. The multiplication factor changes towards end_factor
             in the following epochs. Defaults to 1./3.
@@ -780,6 +792,8 @@ class PolyParamScheduler(_ParamScheduler):
     Args:
         optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
             optimizer.
+        param_name (str): Name of the parameter to be adjusted, such as
+            ``lr``, ``momentum``.
         eta_min (float): Minimum parameter value at the end of scheduling.
             Defaults to 0.
         power (float): The power of the polynomial. Defaults to 1.0.
@@ -882,6 +896,8 @@ class OneCycleParamScheduler(_ParamScheduler):
 
     Args:
         optimizer (Optimizer): Wrapped optimizer.
+        param_name (str): Name of the parameter to be adjusted, such as
+            ``lr``, ``momentum``.
         eta_max (float or list): Upper parameter value boundaries in the cycle
             for each parameter group.
         total_steps (int): The total number of steps in the cycle. Note that
@@ -1094,3 +1110,159 @@ class OneCycleParamScheduler(_ParamScheduler):
             params.append(computed_param)
 
         return params
+
+
+@PARAM_SCHEDULERS.register_module()
+class CosineRestartParamScheduler(_ParamScheduler):
+    """Sets the parameters of each parameter group according to the cosine
+    annealing with restarts scheme. The cosine restart policy anneals the
+    parameter from the initial value to `eta_min` with a cosine annealing
+    schedule and then restarts another period from the maximum value
+    multiplied by `restart_weight`.
+
+    Args:
+        optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
+            optimizer.
+        param_name (str): Name of the parameter to be adjusted, such as
+            ``lr``, ``momentum``.
+        periods (list[int]): Periods for each cosine annealing cycle.
+        restart_weights (list[float]): Restart weights at each
+            restart iteration. Defaults to [1].
+        eta_min (float): Minimum parameter value at the end of scheduling.
+            Defaults to None.
+        eta_min_ratio (float, optional): The ratio of minimum parameter value
+            to the base parameter value. Either `eta_min` or `eta_min_ratio`
+            should be specified. Defaults to None.
+        begin (int): Step at which to start updating the parameters.
+            Defaults to 0.
+        end (int): Step at which to stop updating the parameters.
+            Defaults to INF.
+        last_step (int): The index of last step. Used for resume without
+            state dict. Defaults to -1.
+        by_epoch (bool): Whether the scheduled parameters are updated by
+            epochs. Defaults to True.
+        verbose (bool): Whether to print the value for each update.
+            Defaults to False.
+    """
+
+    def __init__(self,
+                 optimizer: Union[Optimizer, OptimWrapper],
+                 param_name: str,
+                 periods: List[int],
+                 restart_weights: Sequence[float] = (1, ),
+                 eta_min: Optional[float] = None,
+                 eta_min_ratio: Optional[float] = None,
+                 begin: int = 0,
+                 end: int = INF,
+                 last_step: int = -1,
+                 by_epoch: bool = True,
+                 verbose: bool = False):
+        assert (eta_min is None) ^ (eta_min_ratio is None)
+        self.periods = periods
+        self.eta_min = eta_min
+        self.eta_min_ratio = eta_min_ratio
+        self.restart_weights = restart_weights
+        assert (len(self.periods) == len(self.restart_weights)
+                ), 'periods and restart_weights should have the same length.'
+        self.cumulative_periods = [
+            sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
+        ]
+
+        super().__init__(
+            optimizer,
+            param_name=param_name,
+            begin=begin,
+            end=end,
+            last_step=last_step,
+            by_epoch=by_epoch,
+            verbose=verbose)
+
+    @classmethod
+    def build_iter_from_epoch(cls,
+                              *args,
+                              periods,
+                              begin=0,
+                              end=INF,
+                              by_epoch=True,
+                              epoch_length=None,
+                              **kwargs):
+        """Build an iter-based instance of this scheduler from an epoch-based
+        config."""
+        assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
+                         'be converted to iter-based.'
+        assert epoch_length is not None and epoch_length > 0, \
+            f'`epoch_length` must be a positive integer, ' \
+            f'but got {epoch_length}.'
+        periods = [p * epoch_length for p in periods]
+        by_epoch = False
+        begin = int(begin * epoch_length)
+        if end != INF:
+            end = int(end * epoch_length)
+        return cls(
+            *args,
+            periods=periods,
+            begin=begin,
+            end=end,
+            by_epoch=by_epoch,
+            **kwargs)
+
+    def _get_value(self):
+        """Compute value using chainable form of the scheduler."""
+        idx = self.get_position_from_periods(self.last_step,
+                                             self.cumulative_periods)
+        # if current step is not in the periods, return original parameters
+        if idx is None:
+            return [
+                group[self.param_name] for group in self.optimizer.param_groups
+            ]
+        current_weight = self.restart_weights[idx]
+        nearest_restart = 0 if idx == 0 else self.cumulative_periods[idx - 1]
+        current_periods = self.periods[idx]
+        step = self.last_step - nearest_restart
+        values = []
+        for base_value, group in zip(self.base_values,
+                                     self.optimizer.param_groups):
+            eta_max = base_value * current_weight
+            if self.eta_min_ratio is None:
+                eta_min = self.eta_min * (1 - current_weight)
+            else:
+                eta_min = base_value * self.eta_min_ratio * (1 -
+                                                             current_weight)
+            if step == 0:
+                values.append(eta_max)
+
+            elif (step - 1 - current_periods) % (2 * current_periods) == 0:
+                values.append(group[self.param_name] + (eta_max - eta_min) *
+                              (1 - math.cos(math.pi / current_periods)) / 2)
+            else:
+                values.append(
+                    (1 + math.cos(math.pi * step / current_periods)) /
+                    (1 + math.cos(math.pi * (step - 1) / current_periods)) *
+                    (group[self.param_name] - eta_min) + eta_min)
+
+        return values
+
+    @staticmethod
+    def get_position_from_periods(
+            iteration: int, cumulative_periods: List[int]) -> Optional[int]:
+        """Get the position from a period list.
+
+        It will return the index of the right-closest number in the period
+        list.
+        For example, the cumulative_periods = [100, 200, 300, 400],
+        if iteration == 50, return 0;
+        if iteration == 210, return 2;
+        if iteration == 300, return 3.
+
+        Args:
+            iteration (int): Current iteration.
+            cumulative_periods (list[int]): Cumulative period list.
+
+        Returns:
+            Optional[int]: The position of the right-closest number in the
+            period list. If not in the period, return None.
+ """ + for i, period in enumerate(cumulative_periods): + if iteration < period: + return i + return None diff --git a/tests/test_optim/test_scheduler/test_lr_scheduler.py b/tests/test_optim/test_scheduler/test_lr_scheduler.py index 87f83479537cb16fb3b2f5d2b10c4a9c3a97a539..bd537b8680e2d0e4a8f72ddfc43f680cd5dbc0e1 100644 --- a/tests/test_optim/test_scheduler/test_lr_scheduler.py +++ b/tests/test_optim/test_scheduler/test_lr_scheduler.py @@ -7,8 +7,8 @@ import torch.nn.functional as F import torch.optim as optim from mmengine.optim.scheduler import (ConstantLR, CosineAnnealingLR, - ExponentialLR, LinearLR, MultiStepLR, - OneCycleLR, PolyLR, StepLR, + CosineRestartLR, ExponentialLR, LinearLR, + MultiStepLR, OneCycleLR, PolyLR, StepLR, _ParamScheduler) from mmengine.testing import assert_allclose @@ -333,6 +333,34 @@ class TestLRScheduler(TestCase): self.optimizer, power=power, eta_min=min_lr, end=iters + 1) self._test_scheduler_value(scheduler, targets, epochs=10) + def test_cosine_restart_scheduler(self): + with self.assertRaises(AssertionError): + CosineRestartLR( + self.optimizer, + periods=[4, 5], + restart_weights=[1, 0.5], + eta_min=0, + eta_min_ratio=0.1) + with self.assertRaises(AssertionError): + CosineRestartLR( + self.optimizer, + periods=[4, 5], + restart_weights=[1, 0.5, 0.0], + eta_min=0) + single_targets = [ + 0.05, 0.0426776, 0.025, 0.00732233, 0.025, 0.022612712, 0.01636271, + 0.0086372, 0.0023872, 0.0023872 + ] + targets = [ + single_targets, [t * self.layer2_mult for t in single_targets] + ] + scheduler = CosineRestartLR( + self.optimizer, + periods=[4, 5], + restart_weights=[1, 0.5], + eta_min=0) + self._test_scheduler_value(scheduler, targets, epochs=10) + def _check_scheduler_state_dict(self, construct, construct2, epochs=10): scheduler = construct() for _ in range(epochs): @@ -387,6 +415,20 @@ class TestLRScheduler(TestCase): lambda: PolyLR(self.optimizer, power=0.8, eta_min=0.002), epochs=10) + def test_cosine_restart_scheduler_state_dict(self): + self._check_scheduler_state_dict( + lambda: CosineRestartLR( + self.optimizer, + periods=[4, 5], + restart_weights=[1, 0.5], + eta_min=0), + lambda: CosineRestartLR( + self.optimizer, + periods=[4, 6], + restart_weights=[1, 0.5], + eta_min=0), + epochs=10) + def test_step_scheduler_convert_iterbased(self): # invalid epoch_length with self.assertRaises(AssertionError): diff --git a/tests/test_optim/test_scheduler/test_momentum_scheduler.py b/tests/test_optim/test_scheduler/test_momentum_scheduler.py index 7014205e0b0238d4fe3e7d7ce1e6b1a4669bf636..d98f2fe424a2f915eb9c200b6b0d423e67bab15c 100644 --- a/tests/test_optim/test_scheduler/test_momentum_scheduler.py +++ b/tests/test_optim/test_scheduler/test_momentum_scheduler.py @@ -8,6 +8,7 @@ import torch.optim as optim from mmengine.optim.scheduler import (ConstantMomentum, CosineAnnealingMomentum, + CosineRestartMomentum, ExponentialMomentum, LinearMomentum, MultiStepMomentum, PolyMomentum, StepMomentum, _ParamScheduler) @@ -399,6 +400,43 @@ class TestMomentumScheduler(TestCase): self._test_scheduler_value( self.optimizer_with_betas, scheduler, targets, epochs=10) + def test_cosine_restart_scheduler(self): + with self.assertRaises(AssertionError): + CosineRestartMomentum( + self.optimizer, + periods=[4, 5], + restart_weights=[1, 0.5], + eta_min=0, + eta_min_ratio=0.1) + with self.assertRaises(AssertionError): + CosineRestartMomentum( + self.optimizer, + periods=[4, 5], + restart_weights=[1, 0.5, 0.0], + eta_min=0) + single_targets = [ + 0.05, 0.0426776, 0.025, 0.00732233, 0.025, 
0.022612712, 0.01636271, + 0.0086372, 0.0023872, 0.0023872 + ] + targets = [ + single_targets, [t * self.layer2_mult for t in single_targets] + ] + scheduler = CosineRestartMomentum( + self.optimizer, + periods=[4, 5], + restart_weights=[1, 0.5], + eta_min=0) + self._test_scheduler_value( + self.optimizer, scheduler, targets, epochs=10) + + scheduler = CosineRestartMomentum( + self.optimizer_with_betas, + periods=[4, 5], + restart_weights=[1, 0.5], + eta_min=0) + self._test_scheduler_value( + self.optimizer_with_betas, scheduler, targets, epochs=10) + def _check_scheduler_state_dict(self, construct, construct2, epochs=10): scheduler = construct() for _ in range(epochs): @@ -454,6 +492,20 @@ class TestMomentumScheduler(TestCase): lambda: PolyMomentum(self.optimizer, power=0.8, eta_min=0.002), epochs=10) + def test_cosine_restart_scheduler_state_dict(self): + self._check_scheduler_state_dict( + lambda: CosineRestartMomentum( + self.optimizer, + periods=[4, 5], + restart_weights=[1, 0.5], + eta_min=0), + lambda: CosineRestartMomentum( + self.optimizer, + periods=[4, 6], + restart_weights=[1, 0.5], + eta_min=0), + epochs=10) + def test_multi_scheduler_without_overlap_linear_multi_step(self): # use Linear in the first 5 epochs and then use MultiStep epochs = 12 diff --git a/tests/test_optim/test_scheduler/test_param_scheduler.py b/tests/test_optim/test_scheduler/test_param_scheduler.py index 753bacb1281aebcdcd32dc5ded4dd7b8ba219500..e27f0dd7b1c77e0e758666d7024d739787ab8f38 100644 --- a/tests/test_optim/test_scheduler/test_param_scheduler.py +++ b/tests/test_optim/test_scheduler/test_param_scheduler.py @@ -12,12 +12,13 @@ from mmengine.optim import OptimWrapper # yapf: disable from mmengine.optim.scheduler import (ConstantParamScheduler, CosineAnnealingParamScheduler, + CosineRestartParamScheduler, ExponentialParamScheduler, LinearParamScheduler, MultiStepParamScheduler, + OneCycleParamScheduler, PolyParamScheduler, StepParamScheduler, _ParamScheduler) -from mmengine.optim.scheduler.param_scheduler import OneCycleParamScheduler # yapf: enable from mmengine.testing import assert_allclose @@ -406,6 +407,37 @@ class TestParameterScheduler(TestCase): end=iters + 1) self._test_scheduler_value(scheduler, targets, epochs=10) + def test_cosine_restart_scheduler(self): + with self.assertRaises(AssertionError): + CosineRestartParamScheduler( + self.optimizer, + param_name='lr', + periods=[4, 5], + restart_weights=[1, 0.5], + eta_min=0, + eta_min_ratio=0.1) + with self.assertRaises(AssertionError): + CosineRestartParamScheduler( + self.optimizer, + param_name='lr', + periods=[4, 5], + restart_weights=[1, 0.5, 0.0], + eta_min=0) + single_targets = [ + 0.05, 0.0426776, 0.025, 0.00732233, 0.025, 0.022612712, 0.01636271, + 0.0086372, 0.0023872, 0.0023872 + ] + targets = [ + single_targets, [t * self.layer2_mult for t in single_targets] + ] + scheduler = CosineRestartParamScheduler( + self.optimizer, + param_name='lr', + periods=[4, 5], + restart_weights=[1, 0.5], + eta_min=0) + self._test_scheduler_value(scheduler, targets, epochs=10) + def _check_scheduler_state_dict(self, construct, construct2, epochs=10): scheduler = construct() for _ in range(epochs): @@ -483,6 +515,22 @@ class TestParameterScheduler(TestCase): self.optimizer, param_name='lr', power=0.8, eta_min=0.002), epochs=10) + def test_cosine_restart_scheduler_state_dict(self): + self._check_scheduler_state_dict( + lambda: CosineRestartParamScheduler( + self.optimizer, + param_name='lr', + periods=[4, 5], + restart_weights=[1, 0.5], + eta_min=0), + 
lambda: CosineRestartParamScheduler( + self.optimizer, + param_name='lr', + periods=[4, 6], + restart_weights=[1, 0.5], + eta_min=0), + epochs=10) + def test_step_scheduler_convert_iterbased(self): # invalid epoch_length with self.assertRaises(AssertionError):
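The restart bookkeeping in `CosineRestartParamScheduler` reduces to the running sums stored in `cumulative_periods` plus the `get_position_from_periods` lookup. Below is a small, self-contained restatement of that logic (a sketch that mirrors the patch, not part of the diff), using the same `periods=[4, 5]` as the new unit tests:

```python
from typing import List, Optional


def get_position_from_periods(iteration: int,
                              cumulative_periods: List[int]) -> Optional[int]:
    """Return the index of the first cumulative period that exceeds
    ``iteration``, or None once all periods are exhausted."""
    for i, period in enumerate(cumulative_periods):
        if iteration < period:
            return i
    return None


periods = [4, 5]
cumulative_periods = [sum(periods[:i + 1]) for i in range(len(periods))]
assert cumulative_periods == [4, 9]

# Steps 0-3 fall in the first cycle, steps 4-8 in the second (restarted)
# cycle, and from step 9 onwards the lookup returns None.
assert [get_position_from_periods(s, cumulative_periods)
        for s in (0, 3, 4, 8, 9)] == [0, 0, 1, 1, None]
```

Past the last cumulative period the lookup returns `None` and `_get_value` leaves the group values untouched, which is why the final entry of `single_targets` in the tests simply repeats the previous one.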
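For orientation, within a single cycle the chainable recurrence in `_get_value` should reduce to the familiar cosine-annealing closed form, where `s` is the step offset from the most recent restart, `T_i` the current entry of `periods`, and `w_i` the matching restart weight:

\eta_s = \eta_{\min,i} + \tfrac{1}{2}\,(\eta_{\max,i} - \eta_{\min,i})\Bigl(1 + \cos\tfrac{\pi s}{T_i}\Bigr),
\qquad \eta_{\max,i} = w_i \cdot \eta_{\text{base}}

Here \eta_{\min,i} is the effective floor for cycle i (either `eta_min` or the base value times `eta_min_ratio`, additionally scaled by (1 - w_i) in this implementation). Plugging in the test configuration (base LR 0.05, `periods=[4, 5]`, `restart_weights=[1, 0.5]`, `eta_min=0`) gives roughly 0.05, 0.0427, 0.025, 0.0073 for the first cycle and 0.025, 0.0226, 0.0164, 0.0086, 0.0024 for the second, which is where the `single_targets` values in the new tests come from.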
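Finally, a minimal usage sketch of the new `CosineRestartLR`, mirroring the hyperparameters exercised by `test_cosine_restart_scheduler`. The toy model and the manual epoch loop are illustrative assumptions, not part of this diff:

```python
# Illustrative only: assumes mmengine with this patch applied.
import torch.nn as nn
import torch.optim as optim

from mmengine.optim.scheduler import CosineRestartLR

# A toy model and optimizer (hypothetical, matching the base LR of the tests).
model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.05)

# Two cosine cycles: 4 epochs at full weight, then 5 epochs restarted at half
# the base value (restart_weight=0.5), both annealing towards eta_min=0.
scheduler = CosineRestartLR(
    optimizer,
    periods=[4, 5],
    restart_weights=[1, 0.5],
    eta_min=0)

for epoch in range(10):
    # LR used during this epoch; expected to follow roughly
    # 0.05, 0.0427, 0.025, 0.0073, 0.025, 0.0226, 0.0164, 0.0086, 0.0024, ...
    print(epoch, [group['lr'] for group in optimizer.param_groups])
    # ... training iterations would go here ...
    optimizer.step()
    scheduler.step()
```

Because the new classes are registered in `PARAM_SCHEDULERS`, the same schedule should also be constructible from a config dict such as `dict(type='CosineRestartLR', periods=[4, 5], restart_weights=[1, 0.5], eta_min=0, by_epoch=True)`, though the runner wiring is outside the scope of this diff.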