diff --git a/mmengine/optim/scheduler/__init__.py b/mmengine/optim/scheduler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f7ea1d57400ef829d2a71502c9a709a711311f03 --- /dev/null +++ b/mmengine/optim/scheduler/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .lr_scheduler import (ConstantLR, CosineAnnealingLR, ExponentialLR, + LinearLR, MultiStepLR, StepLR) +from .momentum_scheduler import (ConstantMomentum, CosineAnnealingMomentum, + ExponentialMomentum, LinearMomentum, + MultiStepMomentum, StepMomentum) +from .param_scheduler import (ConstantParamScheduler, + CosineAnnealingParamScheduler, + ExponentialParamScheduler, LinearParamScheduler, + MultiStepParamScheduler, StepParamScheduler, + _ParamScheduler) + +__all__ = [ + 'ConstantLR', 'CosineAnnealingLR', 'ExponentialLR', 'LinearLR', + 'MultiStepLR', 'StepLR', 'ConstantMomentum', 'CosineAnnealingMomentum', + 'ExponentialMomentum', 'LinearMomentum', 'MultiStepMomentum', + 'StepMomentum', 'ConstantParamScheduler', 'CosineAnnealingParamScheduler', + 'ExponentialParamScheduler', 'LinearParamScheduler', + 'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler' +] diff --git a/mmengine/optim/scheduler/lr_scheduler.py b/mmengine/optim/scheduler/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..514b8b035c80e8891faffa7a0ef577edfe7edc38 --- /dev/null +++ b/mmengine/optim/scheduler/lr_scheduler.py @@ -0,0 +1,296 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import torch + +from mmengine.registry import PARAM_SCHEDULERS +from .param_scheduler import (INF, ConstantParamScheduler, + CosineAnnealingParamScheduler, + ExponentialParamScheduler, LinearParamScheduler, + MultiStepParamScheduler, StepParamScheduler) + + +@PARAM_SCHEDULERS.register_module() +class ConstantLR(ConstantParamScheduler): + """Decays the learning rate value of each parameter group by a small + constant factor until the number of epoch reaches a pre-defined milestone: + ``end``. Notice that such decay can happen simultaneously with other + changes to the learning rate value from outside this scheduler. + + Args: + optimizer (Optimizer): Wrapped optimizer. + factor (float): The number we multiply learning rate until the + milestone. Defaults to 1./3. + begin (int): Step at which to start updating the learning rate. + Defaults to 0. + end (int): Step at which to stop updating the learning rate. + Defaults to INF. + last_step (int): The index of last step. Used for resume without state + dict. Defaults to -1. + by_epoch (bool): Whether the scheduled learning rate is updated by + epochs. Defaults to True. + verbose (bool): Whether to print the learning rate for each update. + Defaults to False. + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + factor: float = 1.0 / 3, + begin: int = 0, + end: int = INF, + last_step: int = -1, + by_epoch: bool = True, + verbose: bool = False): + super().__init__( + optimizer, + param_name='lr', + factor=factor, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) + + +@PARAM_SCHEDULERS.register_module() +class CosineAnnealingLR(CosineAnnealingParamScheduler): + r"""Set the learning rate of each parameter group using a cosine annealing + schedule, where :math:`\eta_{max}` is set to the initial value and + :math:`T_{cur}` is the number of epochs since the last restart in SGDR: + + .. 
math:: + \begin{aligned} + \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right), + & T_{cur} \neq (2k+1)T_{max}; \\ + \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min}) + \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right), + & T_{cur} = (2k+1)T_{max}. + \end{aligned} + + Notice that because the schedule + is defined recursively, the learning rate can be simultaneously modified + outside this scheduler by other operators. If the learning rate is set + solely by this scheduler, the learning rate at each step becomes: + + .. math:: + \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right) + + It has been proposed in + `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this + only implements the cosine annealing part of SGDR, and not the restarts. + + Args: + optimizer (Optimizer): Wrapped optimizer. + T_max (int): Maximum number of iterations. + eta_min (float): Minimum learning rate. Defaults to 0. + begin (int): Step at which to start updating the learning rate. + Defaults to 0. + end (int): Step at which to stop updating the learning rate. + Defaults to INF. + last_step (int): The index of last step. Used for resume without + state dict. Defaults to -1. + by_epoch (bool): Whether the scheduled learning rate is updated by + epochs. Defaults to True. + verbose (bool): Whether to print the learning rate for each update. + Defaults to False. + + .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: + https://arxiv.org/abs/1608.03983 + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + T_max: int, + eta_min: int = 0, + begin: int = 0, + end: int = INF, + last_step: int = -1, + by_epoch: bool = True, + verbose: bool = False): + super().__init__( + optimizer, + param_name='lr', + T_max=T_max, + eta_min=eta_min, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) + + +@PARAM_SCHEDULERS.register_module() +class ExponentialLR(ExponentialParamScheduler): + """Decays the learning rate of each parameter group by gamma every epoch. + + Args: + optimizer (Optimizer): Wrapped optimizer. + gamma (float): Multiplicative factor of learning rate decay. + begin (int): Step at which to start updating the learning rate. + Defaults to 0. + end (int): Step at which to stop updating the learning rate. + Defaults to INF. + last_step (int): The index of last step. Used for resume without + state dict. Defaults to -1. + by_epoch (bool): Whether the scheduled learning rate is updated by + epochs. Defaults to True. + verbose (bool): Whether to print the learning rate for each update. + Defaults to False. + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + gamma: float, + begin: int = 0, + end: int = INF, + last_step: int = -1, + by_epoch: bool = True, + verbose: bool = False): + super().__init__( + optimizer, + param_name='lr', + gamma=gamma, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) + + +@PARAM_SCHEDULERS.register_module() +class LinearLR(LinearParamScheduler): + """Decays the learning rate of each parameter group by linearly changing + small multiplicative factor until the number of epoch reaches a pre-defined + milestone: ``end``. + + Notice that such decay can happen simultaneously with other changes to the + learning rate from outside this scheduler. + Args: + optimizer (Optimizer): Wrapped optimizer. 
+ start_factor (float): The number we multiply learning rate in the + first epoch. The multiplication factor changes towards end_factor + in the following epochs. Defaults to 1./3. + end_factor (float): The number we multiply learning rate at the end + of linear changing process. Defaults to 1.0. + begin (int): Step at which to start updating the learning rate. + Defaults to 0. + end (int): Step at which to stop updating the learning rate. + Defaults to INF. + last_step (int): The index of last step. Used for resume without + state dict. Defaults to -1. + by_epoch (bool): Whether the scheduled learning rate is updated by + epochs. Defaults to True. + verbose (bool): Whether to print the learning rate for each update. + Defaults to False. + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + start_factor: float = 1.0 / 3, + end_factor: float = 1.0, + begin: int = 0, + end: int = INF, + last_step: int = -1, + by_epoch: bool = True, + verbose: bool = False): + super().__init__( + optimizer, + param_name='lr', + start_factor=start_factor, + end_factor=end_factor, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) + + +@PARAM_SCHEDULERS.register_module() +class MultiStepLR(MultiStepParamScheduler): + """Decays the specified learning rate in each parameter group by gamma once + the number of epoch reaches one of the milestones. Notice that such decay + can happen simultaneously with other changes to the learning rate from + outside this scheduler. + + Args: + optimizer (Optimizer): Wrapped optimizer. + milestones (list): List of epoch indices. Must be increasing. + gamma (float): Multiplicative factor of learning rate decay. + Defaults to 0.1. + begin (int): Step at which to start updating the learning rate. + Defaults to 0. + end (int): Step at which to stop updating the learning rate. + Defaults to INF. + last_step (int): The index of last step. Used for resume without + state dict. Defaults to -1. + by_epoch (bool): Whether the scheduled learning rate is updated by + epochs. Defaults to True. + verbose (bool): Whether to print the learning rate for each update. + Defaults to False. + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + milestones: List[int], + gamma: float = 0.1, + last_step: int = -1, + begin: int = 0, + end: int = INF, + by_epoch: bool = True, + verbose: bool = False): + super().__init__( + optimizer, + param_name='lr', + milestones=milestones, + gamma=gamma, + last_step=last_step, + begin=begin, + end=end, + by_epoch=by_epoch, + verbose=verbose) + + +@PARAM_SCHEDULERS.register_module() +class StepLR(StepParamScheduler): + """Decays the learning rate of each parameter group by gamma every + step_size epochs. Notice that such decay can happen simultaneously with + other changes to the learning rate from outside this scheduler. + + Args: + optimizer (Optimizer): Wrapped optimizer. + step_size (int): Period of learning rate decay. + gamma (float): Multiplicative factor of learning rate decay. + Defaults to 0.1. + begin (int): Step at which to start updating the learning rate. + Defaults to 0. + end (int): Step at which to stop updating the learning rate. + Defaults to INF. + last_step (int): The index of last step. Used for resume without + state dict. Defaults to -1. + by_epoch (bool): Whether the scheduled learning rate is updated by + epochs. Defaults to True. + verbose (bool): Whether to print the learning rate for each update. + Defaults to False. 
+ """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + step_size: int, + gamma: float = 0.1, + begin: int = 0, + end: int = INF, + last_step: int = -1, + by_epoch: bool = True, + verbose: bool = False): + super().__init__( + optimizer, + param_name='lr', + step_size=step_size, + gamma=gamma, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) diff --git a/mmengine/optim/scheduler/momentum_scheduler.py b/mmengine/optim/scheduler/momentum_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..cc882c3b423df50da7e352ad25e8f357da05138d --- /dev/null +++ b/mmengine/optim/scheduler/momentum_scheduler.py @@ -0,0 +1,296 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import torch + +from mmengine.registry import PARAM_SCHEDULERS +from .param_scheduler import (INF, ConstantParamScheduler, + CosineAnnealingParamScheduler, + ExponentialParamScheduler, LinearParamScheduler, + MultiStepParamScheduler, StepParamScheduler) + + +@PARAM_SCHEDULERS.register_module() +class ConstantMomentum(ConstantParamScheduler): + """Decays the momentum value of each parameter group by a small constant + factor until the number of epoch reaches a pre-defined milestone: ``end``. + Notice that such decay can happen simultaneously with other changes to the + momentum value from outside this scheduler. + + Args: + optimizer (Optimizer): Wrapped optimizer. + factor (float): The number we multiply momentum until the milestone. + Defaults to 1./3. + begin (int): Step at which to start updating the momentum. + Defaults to 0. + end (int): Step at which to stop updating the momentum. + Defaults to INF. + last_step (int): The index of last step. Used for resume without state + dict. Defaults to -1. + by_epoch (bool): Whether the scheduled momentum is updated by epochs. + Defaults to True. + verbose (bool): Whether to print the momentum for each update. + Defaults to False. + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + factor: float = 1.0 / 3, + begin: int = 0, + end: int = INF, + last_step: int = -1, + by_epoch: bool = True, + verbose: bool = False): + super().__init__( + optimizer, + param_name='momentum', + factor=factor, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) + + +@PARAM_SCHEDULERS.register_module() +class CosineAnnealingMomentum(CosineAnnealingParamScheduler): + r"""Set the momentum of each parameter group using a cosine annealing + schedule, where :math:`\eta_{max}` is set to the initial value and + :math:`T_{cur}` is the number of epochs since the last restart in SGDR: + + .. math:: + \begin{aligned} + \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right), + & T_{cur} \neq (2k+1)T_{max}; \\ + \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min}) + \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right), + & T_{cur} = (2k+1)T_{max}. + \end{aligned} + + Notice that because the schedule + is defined recursively, the momentum can be simultaneously modified + outside this scheduler by other operators. If the momentum is set + solely by this scheduler, the momentum at each step becomes: + + .. math:: + \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right) + + It has been proposed in + `SGDR: Stochastic Gradient Descent with Warm Restarts`_. 
Note that this + only implements the cosine annealing part of SGDR, and not the restarts. + + Args: + optimizer (Optimizer): Wrapped optimizer. + T_max (int): Maximum number of iterations. + eta_min (float): Minimum momentum value. Defaults to 0. + begin (int): Step at which to start updating the momentum. + Defaults to 0. + end (int): Step at which to stop updating the momentum. + Defaults to INF. + last_step (int): The index of last step. Used for resume without + state dict. Defaults to -1. + by_epoch (bool): Whether the scheduled momentum is updated by + epochs. Defaults to True. + verbose (bool): Whether to print the momentum for each update. + Defaults to False. + + .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: + https://arxiv.org/abs/1608.03983 + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + T_max: int, + eta_min: int = 0, + begin: int = 0, + end: int = INF, + last_step: int = -1, + by_epoch: bool = True, + verbose: bool = False): + super().__init__( + optimizer, + param_name='momentum', + T_max=T_max, + eta_min=eta_min, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) + + +@PARAM_SCHEDULERS.register_module() +class ExponentialMomentum(ExponentialParamScheduler): + """Decays the momentum of each parameter group by gamma every epoch. + + Args: + optimizer (Optimizer): Wrapped optimizer. + gamma (float): Multiplicative factor of momentum value decay. + begin (int): Step at which to start updating the momentum. + Defaults to 0. + end (int): Step at which to stop updating the momentum. + Defaults to INF. + last_step (int): The index of last step. Used for resume without + state dict. Defaults to -1. + by_epoch (bool): Whether the scheduled momentum is updated by + epochs. Defaults to True. + verbose (bool): Whether to print the momentum for each update. + Defaults to False. + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + gamma: float, + begin: int = 0, + end: int = INF, + last_step: int = -1, + by_epoch: bool = True, + verbose: bool = False): + super().__init__( + optimizer, + param_name='momentum', + gamma=gamma, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) + + +@PARAM_SCHEDULERS.register_module() +class LinearMomentum(LinearParamScheduler): + """Decays the momentum of each parameter group by linearly changing + small multiplicative factor until the number of epoch reaches a pre-defined + milestone: ``end``. + + Notice that such decay can happen simultaneously with other changes to the + momentum from outside this scheduler. + Args: + optimizer (Optimizer): Wrapped optimizer. + start_factor (float): The number we multiply momentum in the + first epoch. The multiplication factor changes towards end_factor + in the following epochs. Defaults to 1./3. + end_factor (float): The number we multiply momentum at the end + of linear changing process. Defaults to 1.0. + begin (int): Step at which to start updating the momentum. + Defaults to 0. + end (int): Step at which to stop updating the momentum. + Defaults to INF. + last_step (int): The index of last step. Used for resume without + state dict. Defaults to -1. + by_epoch (bool): Whether the scheduled momentum is updated by + epochs. Defaults to True. + verbose (bool): Whether to print the momentum for each update. + Defaults to False. 
+ """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + start_factor: float = 1.0 / 3, + end_factor: float = 1.0, + begin: int = 0, + end: int = INF, + last_step: int = -1, + by_epoch: bool = True, + verbose: bool = False): + super().__init__( + optimizer, + param_name='momentum', + start_factor=start_factor, + end_factor=end_factor, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) + + +@PARAM_SCHEDULERS.register_module() +class MultiStepMomentum(MultiStepParamScheduler): + """Decays the specified momentum in each parameter group by gamma once the + number of epoch reaches one of the milestones. Notice that such decay can + happen simultaneously with other changes to the momentum from outside this + scheduler. + + Args: + optimizer (Optimizer): Wrapped optimizer. + milestones (list): List of epoch indices. Must be increasing. + gamma (float): Multiplicative factor of momentum value decay. + Defaults to 0.1. + begin (int): Step at which to start updating the momentum. + Defaults to 0. + end (int): Step at which to stop updating the momentum. + Defaults to INF. + last_step (int): The index of last step. Used for resume without + state dict. Defaults to -1. + by_epoch (bool): Whether the scheduled momentum is updated by + epochs. Defaults to True. + verbose (bool): Whether to print the momentum for each update. + Defaults to False. + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + milestones: List[int], + gamma: float = 0.1, + last_step: int = -1, + begin: int = 0, + end: int = INF, + by_epoch: bool = True, + verbose: bool = False): + super().__init__( + optimizer, + param_name='momentum', + milestones=milestones, + gamma=gamma, + last_step=last_step, + begin=begin, + end=end, + by_epoch=by_epoch, + verbose=verbose) + + +@PARAM_SCHEDULERS.register_module() +class StepMomentum(StepParamScheduler): + """Decays the momentum of each parameter group by gamma every step_size + epochs. Notice that such decay can happen simultaneously with other changes + to the momentum from outside this scheduler. + + Args: + optimizer (Optimizer): Wrapped optimizer. + step_size (int): Period of momentum value decay. + gamma (float): Multiplicative factor of momentum value decay. + Defaults to 0.1. + begin (int): Step at which to start updating the momentum. + Defaults to 0. + end (int): Step at which to stop updating the momentum. + Defaults to INF. + last_step (int): The index of last step. Used for resume without + state dict. Defaults to -1. + by_epoch (bool): Whether the scheduled momentum is updated by + epochs. Defaults to True. + verbose (bool): Whether to print the momentum for each update. + Defaults to False. + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + step_size: int, + gamma: float = 0.1, + begin: int = 0, + end: int = INF, + last_step: int = -1, + by_epoch: bool = True, + verbose: bool = False): + super().__init__( + optimizer, + param_name='momentum', + step_size=step_size, + gamma=gamma, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) diff --git a/mmengine/optim/scheduler/param_scheduler.py b/mmengine/optim/scheduler/param_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..bbec0556b0db7a8739656473ebca39dd66f34261 --- /dev/null +++ b/mmengine/optim/scheduler/param_scheduler.py @@ -0,0 +1,600 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import math +import warnings +import weakref +from collections import Counter +from functools import wraps +from typing import Callable, List + +from torch.optim import Optimizer + +from mmengine.registry import PARAM_SCHEDULERS + +INF = int(1e9) + + +class _ParamScheduler: + """Base class for parameter schedulers. + + It should be inherited by all schedulers that schedule parameters in the + optimizer's ``param_groups``. All subclasses should overwrite the + ``_get_value()`` according to their own schedule strategy. + The implementation is motivated by + https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py. + + Args: + optimizer (Optimizer): Wrapped optimizer. + param_name (str): Name of the parameter to be adjusted, such as + ``lr``, ``momentum``. + begin (int): Step at which to start updating the parameters. + Defaults to 0. + end (int): Step at which to stop updating the parameters. + Defaults to INF. + last_step (int): The index of last step. Used for resuming without + state dict. Default value ``-1`` means the ``step`` function is + never be called before. Defaults to -1. + by_epoch (bool): Whether the scheduled parameters are updated by + epochs. Defaults to True. + verbose (bool): Whether to print the value for each update. + Defaults to False. + """ # noqa: E501 + + def __init__(self, + optimizer: Optimizer, + param_name: str, + begin: int = 0, + end: int = INF, + last_step: int = -1, + by_epoch: bool = True, + verbose: bool = False): + + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError('``optimizer`` should be an Optimizer,' + 'but got {}'.format(type(optimizer).__name__)) + self.optimizer = optimizer + self.param_name = param_name + + if end <= begin: + raise ValueError('end should be larger than begin, but got' + ' begin={}, end={}'.format(begin, end)) + self.begin = begin + self.end = end + + self.by_epoch = by_epoch + + assert isinstance(last_step, int) and last_step >= -1 + # Initialize valid step count and base values + if last_step == -1: + for group in optimizer.param_groups: + # If the param is never be scheduled, record the current value + # as the initial value. + group.setdefault(f'initial_{param_name}', group[param_name]) + else: + for i, group in enumerate(optimizer.param_groups): + if f'initial_{param_name}' not in group: + raise KeyError( + f"param 'initial_{param_name}' is not specified " + 'in param_groups[{}] when resuming an optimizer'. + format(i)) + self.base_values = [ + group[f'initial_{param_name}'] for group in optimizer.param_groups + ] + self.last_step = last_step + + # Following https://github.com/pytorch/pytorch/issues/20124 + # We would like to ensure that `scheduler.step()` is called after + # `optimizer.step()` + def with_counter(method: Callable): + if getattr(method, '_with_counter', False): + # `optimizer.step()` has already been replaced, return. + return method + + # Keep a weak reference to the optimizer instance to prevent + # cyclic references. + instance_ref = weakref.ref(method.__self__) # type: ignore + # Get the unbound method for the same purpose. + func = method.__func__ # type: ignore + cls = instance_ref().__class__ # type: ignore + del method + + @wraps(func) + def wrapper(*args, **kwargs): + instance = instance_ref() + instance._global_step += 1 + wrapped = func.__get__(instance, cls) + return wrapped(*args, **kwargs) + + # Note that the returned function here is no longer a bound method, + # so attributes like `__func__` and `__self__` no longer exist. 
+            wrapper._with_counter = True  # type: ignore
+            return wrapper
+
+        # add counter to optimizer
+        self.optimizer.step = with_counter(self.optimizer.step)
+        self.optimizer._global_step = -1
+
+        self._global_step = -1
+        self.verbose = verbose
+
+        self.step()
+
+    def state_dict(self) -> dict:
+        """Returns the state of the scheduler as a :class:`dict`.
+
+        It contains an entry for every variable in self.__dict__ which is not
+        the optimizer.
+
+        Returns:
+            dict: scheduler state.
+        """
+        return {
+            key: value
+            for key, value in self.__dict__.items() if key != 'optimizer'
+        }
+
+    def load_state_dict(self, state_dict: dict):
+        """Loads the scheduler's state.
+
+        Args:
+            state_dict (dict): scheduler state. Should be an object returned
+                from a call to :meth:`state_dict`.
+        """
+        self.__dict__.update(state_dict)
+
+    def get_last_value(self):
+        """Return the last computed value by current scheduler.
+
+        Returns:
+            list: A list of the last computed value of the optimizer's
+                ``param_group``.
+        """
+        return self._last_value
+
+    def _get_value(self):
+        """Compute value using chainable form of the scheduler."""
+        raise NotImplementedError
+
+    def print_value(self, is_verbose: bool, group: int, value: float):
+        """Display the current parameter value.
+
+        Args:
+            is_verbose (bool): Whether to print the value.
+            group (int): The index of the current ``param_group``.
+            value (float): The parameter value.
+        """
+        if is_verbose:
+            print('Adjusting parameter value'
+                  ' of group {} to {:.4e}.'.format(group, value))
+
+    def step(self):
+        """Adjusts the parameter value of each parameter group based on the
+        specified schedule."""
+        # Raise a warning if old pattern is detected
+        # https://github.com/pytorch/pytorch/issues/20124
+        if self._global_step == 0:
+            if not hasattr(self.optimizer.step, '_with_counter'):
+                warnings.warn(
+                    'Seems like `optimizer.step()` has been overridden after '
+                    'parameter value scheduler initialization. Please, make '
+                    'sure to call `optimizer.step()` before '
+                    '`scheduler.step()`. See more details at '
+                    'https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate',  # noqa: E501
+                    UserWarning)
+
+            # Just check if there were two first scheduler.step() calls
+            # before optimizer.step()
+            elif self.optimizer._global_step < 0:
+                warnings.warn(
+                    'Detected call of `scheduler.step()` before '
+                    '`optimizer.step()`. In PyTorch 1.1.0 and later, you '
+                    'should call them in the opposite order: '
+                    '`optimizer.step()` before `scheduler.step()`. '
+                    'Failure to do this will result in PyTorch skipping '
+                    'the first value of the parameter value schedule. '
+                    'See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate',  # noqa: E501
+                    UserWarning)
+        self._global_step += 1
+
+        # Compute parameter value per param group in the effective range
+        if self.begin <= self._global_step < self.end:
+            self.last_step += 1
+            values = self._get_value()
+
+            for i, data in enumerate(zip(self.optimizer.param_groups, values)):
+                param_group, value = data
+                param_group[self.param_name] = value
+                self.print_value(self.verbose, i, value)
+
+        self._last_value = [
+            group[self.param_name] for group in self.optimizer.param_groups
+        ]
+
+
+@PARAM_SCHEDULERS.register_module()
+class StepParamScheduler(_ParamScheduler):
+    """Decays the parameter value of each parameter group by gamma every
+    step_size epochs. Notice that such decay can happen simultaneously with
+    other changes to the parameter value from outside this scheduler.
+
+    Args:
+        optimizer (Optimizer): Wrapped optimizer.
+ step_size (int): Period of parameter value decay. + gamma (float): Multiplicative factor of parameter value decay. + Defaults to 0.1. + begin (int): Step at which to start updating the parameters. + Defaults to 0. + end (int): Step at which to stop updating the parameters. + Defaults to INF. + last_step (int): The index of last step. Used for resume without + state dict. Defaults to -1. + by_epoch (bool): Whether the scheduled parameters are updated by + epochs. Defaults to True. + verbose (bool): Whether to print the value for each update. + Defaults to False. + """ + + def __init__(self, + optimizer: Optimizer, + param_name: str, + step_size: int, + gamma: float = 0.1, + begin: int = 0, + end: int = INF, + last_step: int = -1, + by_epoch: bool = True, + verbose: bool = False): + self.step_size = step_size + self.gamma = gamma + super().__init__( + optimizer=optimizer, + param_name=param_name, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) + + def _get_value(self): + if (self.last_step == 0) or (self.last_step % self.step_size != 0): + return [ + group[self.param_name] for group in self.optimizer.param_groups + ] + return [ + group[self.param_name] * self.gamma + for group in self.optimizer.param_groups + ] + + +@PARAM_SCHEDULERS.register_module() +class MultiStepParamScheduler(_ParamScheduler): + """Decays the specified parameter in each parameter group by gamma once the + number of epoch reaches one of the milestones. Notice that such decay can + happen simultaneously with other changes to the parameter from outside this + scheduler. + + Args: + optimizer (Optimizer): Wrapped optimizer. + milestones (list): List of epoch indices. Must be increasing. + gamma (float): Multiplicative factor of parameter value decay. + Defaults to 0.1. + begin (int): Step at which to start updating the parameters. + Defaults to 0. + end (int): Step at which to stop updating the parameters. + Defaults to INF. + last_step (int): The index of last step. Used for resume without + state dict. Defaults to -1. + by_epoch (bool): Whether the scheduled parameters are updated by + epochs. Defaults to True. + verbose (bool): Whether to print the value for each update. + Defaults to False. + """ + + def __init__(self, + optimizer: Optimizer, + param_name: str, + milestones: List[int], + gamma: float = 0.1, + last_step: int = -1, + begin: int = 0, + end: int = INF, + by_epoch: bool = True, + verbose: bool = False): + self.milestones = Counter(milestones) + self.gamma = gamma + super().__init__( + optimizer, + param_name=param_name, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) + + def _get_value(self): + if self.last_step not in self.milestones: + return [ + group[self.param_name] for group in self.optimizer.param_groups + ] + return [ + group[self.param_name] * + self.gamma**self.milestones[self.last_step] + for group in self.optimizer.param_groups + ] + + +@PARAM_SCHEDULERS.register_module() +class ConstantParamScheduler(_ParamScheduler): + """Decays the parameter value of each parameter group by a small constant + factor until the number of epoch reaches a pre-defined milestone: ``end``. + Notice that such decay can happen simultaneously with other changes to the + parameter value from outside this scheduler. + + Args: + optimizer (Optimizer): Wrapped optimizer. + factor (float): The number we multiply parameter value until the + milestone. Defaults to 1./3. + begin (int): Step at which to start updating the parameters. 
+ Defaults to 0. + end (int): Step at which to stop updating the parameters. + Defaults to INF. + last_step (int): The index of last step. Used for resume without + state dict. Defaults to -1. + by_epoch (bool): Whether the scheduled parameters are updated by + epochs. Defaults to True. + verbose (bool): Whether to print the value for each update. + Defaults to False. + """ + + def __init__(self, + optimizer: Optimizer, + param_name: str, + factor: float = 1.0 / 3, + begin: int = 0, + end: int = INF, + last_step: int = -1, + by_epoch: bool = True, + verbose: bool = False): + if factor > 1.0 or factor < 0: + raise ValueError( + 'Constant multiplicative factor should between 0 and 1.') + + self.factor = factor + self.total_iters = end - begin - 1 + super().__init__( + optimizer, + param_name=param_name, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) + + def _get_value(self): + if self.last_step == 0: + return [ + group[self.param_name] * self.factor + for group in self.optimizer.param_groups + ] + + if (self.last_step > self.total_iters + or (self.last_step != self.total_iters)): + return [ + group[self.param_name] for group in self.optimizer.param_groups + ] + + if self.last_step == self.total_iters: + return [ + group[self.param_name] * (1.0 / self.factor) + for group in self.optimizer.param_groups + ] + + +@PARAM_SCHEDULERS.register_module() +class ExponentialParamScheduler(_ParamScheduler): + """Decays the parameter value of each parameter group by gamma every epoch. + + Args: + optimizer (Optimizer): Wrapped optimizer. + gamma (float): Multiplicative factor of parameter value decay. + begin (int): Step at which to start updating the parameters. + Defaults to 0. + end (int): Step at which to stop updating the parameters. + Defaults to INF. + last_step (int): The index of last step. Used for resume without + state dict. Defaults to -1. + by_epoch (bool): Whether the scheduled parameters are updated by + epochs. Defaults to True. + verbose (bool): Whether to print the value for each update. + Defaults to False. + """ + + def __init__(self, + optimizer: Optimizer, + param_name: str, + gamma: float, + begin: int = 0, + end: int = INF, + last_step: int = -1, + by_epoch: bool = True, + verbose: bool = False): + self.gamma = gamma + super().__init__( + optimizer, + param_name=param_name, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) + + def _get_value(self): + if self.last_step == 0: + return [ + group[self.param_name] for group in self.optimizer.param_groups + ] + return [ + group[self.param_name] * self.gamma + for group in self.optimizer.param_groups + ] + + +@PARAM_SCHEDULERS.register_module() +class CosineAnnealingParamScheduler(_ParamScheduler): + r"""Set the parameter value of each parameter group using a cosine annealing + schedule, where :math:`\eta_{max}` is set to the initial value and + :math:`T_{cur}` is the number of epochs since the last restart in SGDR: + + .. math:: + \begin{aligned} + \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right), + & T_{cur} \neq (2k+1)T_{max}; \\ + \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min}) + \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right), + & T_{cur} = (2k+1)T_{max}. + \end{aligned} + + Notice that because the schedule + is defined recursively, the parameter value can be simultaneously modified + outside this scheduler by other operators. 
If the parameter value is set + solely by this scheduler, the parameter value at each step becomes: + + .. math:: + \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right) + + It has been proposed in + `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this + only implements the cosine annealing part of SGDR, and not the restarts. + + Args: + optimizer (Optimizer): Wrapped optimizer. + T_max (int): Maximum number of iterations. + eta_min (float): Minimum parameter value. Defaults to 0. + begin (int): Step at which to start updating the parameters. + Defaults to 0. + end (int): Step at which to stop updating the parameters. + Defaults to INF. + last_step (int): The index of last step. Used for resume without + state dict. Defaults to -1. + by_epoch (bool): Whether the scheduled parameters are updated by + epochs. Defaults to True. + verbose (bool): Whether to print the value for each update. + Defaults to False. + + .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: + https://arxiv.org/abs/1608.03983 + """ + + def __init__(self, + optimizer: Optimizer, + param_name: str, + T_max: int, + eta_min: float = 0., + begin: int = 0, + end: int = INF, + last_step: int = -1, + by_epoch: bool = True, + verbose: bool = False): + self.T_max = T_max + self.eta_min = eta_min + super().__init__( + optimizer, + param_name=param_name, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) + + def _get_value(self): + if self.last_step == 0: + return [ + group[self.param_name] for group in self.optimizer.param_groups + ] + elif (self.last_step - 1 - self.T_max) % (2 * self.T_max) == 0: + return [ + group[self.param_name] + (base_value - self.eta_min) * + (1 - math.cos(math.pi / self.T_max)) / 2 + for base_value, group in zip(self.base_values, + self.optimizer.param_groups) + ] + return [(1 + math.cos(math.pi * self.last_step / self.T_max)) / + (1 + math.cos(math.pi * (self.last_step - 1) / self.T_max)) * + (group[self.param_name] - self.eta_min) + self.eta_min + for group in self.optimizer.param_groups] + + +@PARAM_SCHEDULERS.register_module() +class LinearParamScheduler(_ParamScheduler): + """Decays the parameter value of each parameter group by linearly changing + small multiplicative factor until the number of epoch reaches a pre-defined + milestone: ``end``. + + Notice that such decay can happen simultaneously with other changes to the + parameter value from outside this scheduler. + Args: + optimizer (Optimizer): Wrapped optimizer. + start_factor (float): The number we multiply parameter value in the + first epoch. The multiplication factor changes towards end_factor + in the following epochs. Defaults to 1./3. + end_factor (float): The number we multiply parameter value at the end + of linear changing process. Defaults to 1.0. + begin (int): Step at which to start updating the parameters. + Defaults to 0. + end (int): Step at which to stop updating the parameters. + Defaults to INF. + last_step (int): The index of last step. Used for resume without + state dict. Defaults to -1. + by_epoch (bool): Whether the scheduled parameters are updated by + epochs. Defaults to True. + verbose (bool): Whether to print the value for each update. + Defaults to False. 
+    """
+
+    def __init__(self,
+                 optimizer: Optimizer,
+                 param_name: str,
+                 start_factor: float = 1.0 / 3,
+                 end_factor: float = 1.0,
+                 begin: int = 0,
+                 end: int = INF,
+                 last_step: int = -1,
+                 by_epoch: bool = True,
+                 verbose: bool = False):
+        if start_factor > 1.0 or start_factor < 0:
+            raise ValueError(
+                'Starting multiplicative factor should be between 0 and 1.')
+
+        if end_factor > 1.0 or end_factor < 0:
+            raise ValueError(
+                'Ending multiplicative factor should be between 0 and 1.')
+
+        self.start_factor = start_factor
+        self.end_factor = end_factor
+        self.total_iters = end - begin - 1
+        super().__init__(
+            optimizer,
+            param_name=param_name,
+            begin=begin,
+            end=end,
+            last_step=last_step,
+            by_epoch=by_epoch,
+            verbose=verbose)
+
+    def _get_value(self):
+
+        if self.last_step == 0:
+            return [
+                group[self.param_name] * self.start_factor
+                for group in self.optimizer.param_groups
+            ]
+
+        return [
+            group[self.param_name] *
+            (1. + (self.end_factor - self.start_factor) /
+             (self.total_iters * self.start_factor + (self.last_step - 1) *
+              (self.end_factor - self.start_factor)))
+            for group in self.optimizer.param_groups
+        ]
diff --git a/mmengine/registry/__init__.py b/mmengine/registry/__init__.py
index d7f8ee304ca8e3b856d7bdf262910e7c2ab089c9..d8602d4f5f38b4863c44bab4819e814c98363f69 100644
--- a/mmengine/registry/__init__.py
+++ b/mmengine/registry/__init__.py
@@ -1,11 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .registry import Registry, build_from_cfg
 from .root import (DATA_SAMPLERS, DATASETS, HOOKS, MODELS,
-                   OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, RUNNER_CONSTRUCTORS,
-                   RUNNERS, TASK_UTILS, TRANSFORMS, WEIGHT_INITIALIZERS)
+                   OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, PARAM_SCHEDULERS,
+                   RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS, TRANSFORMS,
+                   WEIGHT_INITIALIZERS)

 __all__ = [
     'Registry', 'build_from_cfg', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS',
     'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS',
-    'OPTIMIZERS', 'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS'
+    'OPTIMIZERS', 'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS'
 ]
diff --git a/mmengine/registry/root.py b/mmengine/registry/root.py
index 71ff9dd6fa514ce640d437527fdc743f820fe9fa..c67a8d5b67f294aad22e915f7d8d171e750fa866 100644
--- a/mmengine/registry/root.py
+++ b/mmengine/registry/root.py
@@ -29,6 +29,8 @@ WEIGHT_INITIALIZERS = Registry('weight initializer')
 OPTIMIZERS = Registry('optimizer')
 # manage constructors that customize the optimization hyperparameters.
 OPTIMIZER_CONSTRUCTORS = Registry('optimizer constructor')
+# manage all kinds of parameter schedulers like `MultiStepLR`
+PARAM_SCHEDULERS = Registry('parameter scheduler')
 # manage task-specific modules like anchor generators and box coders
 TASK_UTILS = Registry('task util')
diff --git a/tests/test_optim/test_scheduler/test_lr_scheduler.py b/tests/test_optim/test_scheduler/test_lr_scheduler.py
index 56e07bba48c8191da64d0836d80054097907478c..8d53990b1c7abc581a990016fd78673c6a6200c8 100644
--- a/tests/test_optim/test_scheduler/test_lr_scheduler.py
+++ b/tests/test_optim/test_scheduler/test_lr_scheduler.py
@@ -1,3 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
import math from unittest import TestCase @@ -39,7 +40,7 @@ class TestLRScheduler(TestCase): _ParamScheduler(self.optimizer, param_name='lr') def test_invalid_optimizer(self): - with self.assertRaisesRegex(TypeError, 'is not an Optimizer'): + with self.assertRaisesRegex(TypeError, 'should be an Optimizer'): StepLR('invalid_optimizer', step_size=1) def test_overwrite_optimzer_step(self): diff --git a/tests/test_optim/test_scheduler/test_momentum_scheduler.py b/tests/test_optim/test_scheduler/test_momentum_scheduler.py index c2412e1009300ee75806e0eff521cdd4a5fc27cf..9b144a0acbe60f839347b80d4ba57bb36e1fdadc 100644 --- a/tests/test_optim/test_scheduler/test_momentum_scheduler.py +++ b/tests/test_optim/test_scheduler/test_momentum_scheduler.py @@ -1,3 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. import math from unittest import TestCase @@ -37,7 +38,7 @@ class TestMomentumScheduler(TestCase): self.model.parameters(), lr=0.01, momentum=0.05, weight_decay=5e-4) def test_invalid_optimizer(self): - with self.assertRaisesRegex(TypeError, 'is not an Optimizer'): + with self.assertRaisesRegex(TypeError, 'should be an Optimizer'): StepMomentum('invalid_optimizer', step_size=1) def test_overwrite_optimzer_step(self): diff --git a/tests/test_optim/test_scheduler/test_param_scheduler.py b/tests/test_optim/test_scheduler/test_param_scheduler.py index cc4af0fb4cd3778208ede7ec3af7b48be9af1497..8303d4ce969892a777bfd3ccc0ec0d71d4db0040 100644 --- a/tests/test_optim/test_scheduler/test_param_scheduler.py +++ b/tests/test_optim/test_scheduler/test_param_scheduler.py @@ -1,3 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. import math from unittest import TestCase @@ -42,7 +43,7 @@ class TestParameterScheduler(TestCase): _ParamScheduler(self.optimizer, param_name='lr') def test_invalid_optimizer(self): - with self.assertRaisesRegex(TypeError, 'is not an Optimizer'): + with self.assertRaisesRegex(TypeError, 'should be an Optimizer'): StepParamScheduler('invalid_optimizer', 'lr', step_size=1) def test_overwrite_optimzer_step(self):
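
For reviewers, a minimal usage sketch of the schedulers introduced above (not part of the patch itself). It relies only on names added in this diff (`MultiStepLR`, `PARAM_SCHEDULERS`, `get_last_value()`) plus standard `torch` APIs; the registry-based construction assumes the usual mmcv-style `build_from_cfg(cfg, registry, default_args)` signature, and the toy model, optimizer and milestone values are purely illustrative.

```python
import torch

from mmengine.optim.scheduler import MultiStepLR
from mmengine.registry import PARAM_SCHEDULERS, build_from_cfg

# Hypothetical model and optimizer, used only for illustration.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Direct construction: multiply the lr of every param group by 0.1 once the
# step counter reaches 8, and again at 11.
scheduler = MultiStepLR(optimizer, milestones=[8, 11], gamma=0.1)

# Equivalent construction through the new PARAM_SCHEDULERS registry
# (assumes the usual mmcv-style build_from_cfg signature).
scheduler = build_from_cfg(
    dict(type='MultiStepLR', milestones=[8, 11], gamma=0.1),
    PARAM_SCHEDULERS,
    default_args=dict(optimizer=optimizer))

for _ in range(12):
    optimizer.zero_grad()
    model(torch.randn(2, 4)).sum().backward()
    # As the warning in _ParamScheduler.step() states, optimizer.step()
    # must be called before scheduler.step().
    optimizer.step()
    scheduler.step()
    print(scheduler.get_last_value())  # current lr of each param group
```

Note that in this patch `by_epoch` is only stored on the scheduler; nothing in `_ParamScheduler.step()` consumes it yet, so in a standalone loop like the one above the flag has no effect.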