diff --git a/docs/en/api/optim.rst b/docs/en/api/optim.rst index 634884c9f522d437fbe8e50c2515d9218848c8b7..3de2ade6a2e3b1c7aa7c74a6629991c79c052f92 100644 --- a/docs/en/api/optim.rst +++ b/docs/en/api/optim.rst @@ -63,3 +63,6 @@ Scheduler StepLR StepMomentum StepParamScheduler + ReduceOnPlateauLR + ReduceOnPlateauMomentum + ReduceOnPlateauParamScheduler diff --git a/docs/zh_cn/api/optim.rst b/docs/zh_cn/api/optim.rst index 634884c9f522d437fbe8e50c2515d9218848c8b7..3de2ade6a2e3b1c7aa7c74a6629991c79c052f92 100644 --- a/docs/zh_cn/api/optim.rst +++ b/docs/zh_cn/api/optim.rst @@ -63,3 +63,6 @@ Scheduler StepLR StepMomentum StepParamScheduler + ReduceOnPlateauLR + ReduceOnPlateauMomentum + ReduceOnPlateauParamScheduler diff --git a/mmengine/hooks/param_scheduler_hook.py b/mmengine/hooks/param_scheduler_hook.py index cb033ce4a56f524d2eab41fbaeff077968fd7891..3b2f1e610a9db058e99d03faf607514d73bce030 100644 --- a/mmengine/hooks/param_scheduler_hook.py +++ b/mmengine/hooks/param_scheduler_hook.py @@ -1,7 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Optional, Union +from typing import Dict, Optional, Union +from mmengine.optim import _ParamScheduler from mmengine.registry import HOOKS +from mmengine.utils import is_list_of from .hook import Hook DATA_BATCH = Optional[Union[dict, tuple, list]] @@ -19,7 +21,7 @@ class ParamSchedulerHook(Hook): batch_idx: int, data_batch: DATA_BATCH = None, outputs: Optional[dict] = None) -> None: - """Call step function for each scheduler after each iteration. + """Call step function for each scheduler after each training iteration. Args: runner (Runner): The runner of the training process. @@ -32,15 +34,15 @@ class ParamSchedulerHook(Hook): keep ``data_batch`` here. """ + if runner.param_schedulers is None: + return + def step(param_schedulers): assert isinstance(param_schedulers, list) for scheduler in param_schedulers: if not scheduler.by_epoch: scheduler.step() - if runner.param_schedulers is None: - return - if isinstance(runner.param_schedulers, list): step(runner.param_schedulers) elif isinstance(runner.param_schedulers, dict): @@ -53,21 +55,67 @@ class ParamSchedulerHook(Hook): f'but got {runner.param_schedulers}') def after_train_epoch(self, runner) -> None: - """Call step function for each scheduler after each epoch. + """Call step function for each scheduler after each training epoch. Args: runner (Runner): The runner of the training process. """ + if runner.param_schedulers is None: + return + def step(param_schedulers): assert isinstance(param_schedulers, list) for scheduler in param_schedulers: if scheduler.by_epoch: scheduler.step() + if isinstance(runner.param_schedulers, list): + step(runner.param_schedulers) + elif isinstance(runner.param_schedulers, dict): + for param_schedulers in runner.param_schedulers.values(): + step(param_schedulers) + else: + raise TypeError( + 'runner.param_schedulers should be list of ParamScheduler or ' + 'a dict containing list of ParamScheduler, ' + f'but got {runner.param_schedulers}') + + def after_val_epoch(self, + runner, + metrics: Optional[Dict[str, float]] = None) -> None: + """Call step function for each scheduler which has attribute + ``need_val_args`` after each validation epoch. + + Args: + runner (Runner): The runner of the validation process. + metrics (Dict[str, float], optional): Evaluation results of all + metrics on validation dataset. The keys are the names of the + metrics, and the values are corresponding results. + + Note: + if ``runner.param_schedulers`` is not built before, + the hook ``after_val_epoch`` will be skipped. + """ + if runner.param_schedulers is None: return + # avoid counting scheduler._global_step + # it has counted in after_train_* hook + if metrics is None: + return + + def step(param_schedulers): + # check param_schedulers is list and built + if not is_list_of(param_schedulers, _ParamScheduler): + return + + for scheduler in param_schedulers: + if (scheduler.by_epoch + and getattr(scheduler, 'need_val_args', False)): + scheduler.step(metrics) + if isinstance(runner.param_schedulers, list): step(runner.param_schedulers) elif isinstance(runner.param_schedulers, dict): diff --git a/mmengine/optim/__init__.py b/mmengine/optim/__init__.py index 9b34441ef45b4de0a618634cc2d5cc73f1b038f2..72118b179f3f07e2ee4fa0f1cceb1e1a3c394332 100644 --- a/mmengine/optim/__init__.py +++ b/mmengine/optim/__init__.py @@ -11,8 +11,10 @@ from .scheduler import (ConstantLR, ConstantMomentum, ConstantParamScheduler, MultiStepLR, MultiStepMomentum, MultiStepParamScheduler, OneCycleLR, OneCycleParamScheduler, PolyLR, PolyMomentum, - PolyParamScheduler, StepLR, StepMomentum, - StepParamScheduler, _ParamScheduler) + PolyParamScheduler, ReduceOnPlateauLR, + ReduceOnPlateauMomentum, ReduceOnPlateauParamScheduler, + StepLR, StepMomentum, StepParamScheduler, + _ParamScheduler) # yapf: enable __all__ = [ @@ -25,5 +27,6 @@ __all__ = [ 'LinearParamScheduler', 'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler', 'OptimWrapper', 'AmpOptimWrapper', 'OptimWrapperDict', 'OneCycleParamScheduler', 'OneCycleLR', 'PolyLR', 'PolyMomentum', - 'PolyParamScheduler' + 'PolyParamScheduler', 'ReduceOnPlateauLR', 'ReduceOnPlateauMomentum', + 'ReduceOnPlateauParamScheduler' ] diff --git a/mmengine/optim/scheduler/__init__.py b/mmengine/optim/scheduler/__init__.py index c12d1c2970d39561e3ddec2e92e49f67bbfdc4ba..48ccc34bc41b07442e2494b03a303b3c0054b42b 100644 --- a/mmengine/optim/scheduler/__init__.py +++ b/mmengine/optim/scheduler/__init__.py @@ -2,21 +2,22 @@ # yapf: disable from .lr_scheduler import (ConstantLR, CosineAnnealingLR, CosineRestartLR, ExponentialLR, LinearLR, MultiStepLR, OneCycleLR, - PolyLR, StepLR) + PolyLR, ReduceOnPlateauLR, StepLR) from .momentum_scheduler import (ConstantMomentum, CosineAnnealingMomentum, CosineRestartMomentum, ExponentialMomentum, LinearMomentum, MultiStepMomentum, - PolyMomentum, StepMomentum) + PolyMomentum, ReduceOnPlateauMomentum, + StepMomentum) from .param_scheduler import (ConstantParamScheduler, CosineAnnealingParamScheduler, CosineRestartParamScheduler, ExponentialParamScheduler, LinearParamScheduler, MultiStepParamScheduler, OneCycleParamScheduler, - PolyParamScheduler, StepParamScheduler, - _ParamScheduler) + PolyParamScheduler, + ReduceOnPlateauParamScheduler, + StepParamScheduler, _ParamScheduler) # yapf: enable - __all__ = [ 'ConstantLR', 'CosineAnnealingLR', 'ExponentialLR', 'LinearLR', 'MultiStepLR', 'StepLR', 'ConstantMomentum', 'CosineAnnealingMomentum', @@ -26,5 +27,6 @@ __all__ = [ 'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler', 'PolyParamScheduler', 'PolyLR', 'PolyMomentum', 'OneCycleParamScheduler', 'OneCycleLR', 'CosineRestartParamScheduler', 'CosineRestartLR', - 'CosineRestartMomentum' + 'CosineRestartMomentum', 'ReduceOnPlateauParamScheduler', + 'ReduceOnPlateauLR', 'ReduceOnPlateauMomentum' ] diff --git a/mmengine/optim/scheduler/lr_scheduler.py b/mmengine/optim/scheduler/lr_scheduler.py index b1eeabec5f794e8130158ac710907f33827c8f77..08b98ee76fb444c0ed648dd07f5a8ce3cb7837b3 100644 --- a/mmengine/optim/scheduler/lr_scheduler.py +++ b/mmengine/optim/scheduler/lr_scheduler.py @@ -1,11 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmengine.registry import PARAM_SCHEDULERS +# yapf: disable from .param_scheduler import (ConstantParamScheduler, CosineAnnealingParamScheduler, CosineRestartParamScheduler, ExponentialParamScheduler, LinearParamScheduler, MultiStepParamScheduler, OneCycleParamScheduler, - PolyParamScheduler, StepParamScheduler) + PolyParamScheduler, + ReduceOnPlateauParamScheduler, + StepParamScheduler) + +# yapf: enable class LRSchedulerMixin: @@ -314,3 +319,60 @@ class CosineRestartLR(LRSchedulerMixin, CosineRestartParamScheduler): verbose (bool): Whether to print the value for each update. Defaults to False. """ + + +@PARAM_SCHEDULERS.register_module() +class ReduceOnPlateauLR(LRSchedulerMixin, ReduceOnPlateauParamScheduler): + """Reduce the learning rate of each parameter group when a metric has + stopped improving. Models often benefit from reducing the learning rate by + a factor of 2-10 once learning stagnates. This scheduler reads a metrics + quantity and if no improvement is seen for a ``patience`` number of epochs, + the learning rate is reduced. + + Args: + optimizer (Optimizer or OptimWrapper): optimizer or Wrapped + optimizer. + monitor (str): Key name of the value to monitor in metrics dict. + rule (str): One of `less`, `greater`. In `less` rule, learning rate + will be reduced when the quantity monitored has stopped + decreasing; in `greater` rule it will be reduced when the + quantity monitored has stopped increasing. Defaults to 'less'. + The ``rule`` is the renaming of ``mode`` in pytorch. + factor (float): Factor by which the learning rate will be + reduced. new_param = param * factor. Defaults to 0.1. + patience (int): Number of epochs with no improvement after + which learning rate will be reduced. For example, if + ``patience = 2``, then we will ignore the first 2 epochs + with no improvement, and will only decrease the learning rate after + the 3rd epoch if the monitor value still hasn't improved then. + Defaults to 10. + threshold (float): Threshold for measuring the new optimum, + to only focus on significant changes. Defaults to 1e-4. + threshold_rule (str): One of `rel`, `abs`. In `rel` rule, + dynamic_threshold = best * ( 1 + threshold ) in 'greater' + rule or best * ( 1 - threshold ) in `less` rule. + In `abs` rule, dynamic_threshold = best + threshold in + `greater` rule or best - threshold in `less` rule. + Defaults to 'rel'. + cooldown (int): Number of epochs to wait before resuming + normal operation after learning rate has been reduced. + Defaults to 0. + min_value (float or list[float]): A scalar or a sequence of scalars. + A lower bound on the learning rate of each parameter group + respectively. Defaults to 0. . + eps (float): Minimal decay applied to learning rate. If the difference + between new and old learning rate is smaller than eps, the update + is ignored. Defaults to 1e-8. + begin (int): Step at which to start triggering the scheduler + to monitor in val within the interval calculated + according to epoch of training. Defaults to 0. + end (int): Step at which to stop triggering the scheduler + to monitor in val within the interval calculated + according to epoch of training. Defaults to INF. + last_step (int): The index of last step. Used for resume without + state dict. Defaults to -1. + by_epoch (bool): Whether the scheduled parameters are updated by + epochs. Defaults to True. + verbose (bool): Whether to print the value for each update. + Defaults to False. + """ diff --git a/mmengine/optim/scheduler/momentum_scheduler.py b/mmengine/optim/scheduler/momentum_scheduler.py index 102b173146f525a84ad5892d5e544b6ea44c2248..b15c6c914f100e89ee158df650f938cef5df57d3 100644 --- a/mmengine/optim/scheduler/momentum_scheduler.py +++ b/mmengine/optim/scheduler/momentum_scheduler.py @@ -1,12 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmengine.registry import PARAM_SCHEDULERS +# yapf: disable from .param_scheduler import (ConstantParamScheduler, CosineAnnealingParamScheduler, CosineRestartParamScheduler, ExponentialParamScheduler, LinearParamScheduler, MultiStepParamScheduler, PolyParamScheduler, + ReduceOnPlateauParamScheduler, StepParamScheduler) +# yapf: enable + class MomentumSchedulerMixin: """A mixin class for momentum schedulers. @@ -32,8 +36,8 @@ class MomentumSchedulerMixin: super().__init__(optimizer, param_name, *args, **kwargs) def step(self): - """Adjusts the parameter value of each parameter group based on the - specified schedule.""" + """Adjusts the momentum of each parameter group based on the specified + schedule.""" super().step() if self.use_betas: for group in self.optimizer.param_groups: @@ -281,3 +285,77 @@ class CosineRestartMomentum(MomentumSchedulerMixin, verbose (bool): Whether to print the value for each update. Defaults to False. """ + + +@PARAM_SCHEDULERS.register_module() +class ReduceOnPlateauMomentum(MomentumSchedulerMixin, + ReduceOnPlateauParamScheduler): + """Reduce the momentum of each parameter group when a metric has stopped + improving. Models often benefit from reducing the momentum by a factor of + 2-10 once learning stagnates. This scheduler reads a metrics quantity and + if no improvement is seen for a ``patience`` number of epochs, the momentum + is reduced. + + Args: + optimizer (Optimizer or OptimWrapper): optimizer or Wrapped + optimizer. + monitor (str): Key name of the value to monitor in metrics dict. + rule (str): One of `less`, `greater`. In `less` rule, momentum will + be reduced when the quantity monitored has stopped + decreasing; in `greater` rule it will be reduced when the + quantity monitored has stopped increasing. Defaults to 'less'. + The ``rule`` is the renaming of ``mode`` in pytorch. + factor (float): Factor by which the momentum will be + reduced. new_param = param * factor. Defaults to 0.1. + patience (int): Number of epochs with no improvement after + which momentum will be reduced. For example, if + ``patience = 2``, then we will ignore the first 2 epochs + with no improvement, and will only decrease the momentum after + the 3rd epoch if the monitor value still hasn't improved then. + Defaults to 10. + threshold (float): Threshold for measuring the new optimum, + to only focus on significant changes. Defaults to 1e-4. + threshold_rule (str): One of `rel`, `abs`. In `rel` rule, + dynamic_threshold = best * ( 1 + threshold ) in 'greater' + rule or best * ( 1 - threshold ) in `less` rule. + In `abs` rule, dynamic_threshold = best + threshold in + `greater` rule or best - threshold in `less` rule. + Defaults to 'rel'. + cooldown (int): Number of epochs to wait before resuming + normal operation after momentum has been reduced. Defaults to 0. + min_value (float or list[float]): A scalar or a sequence of scalars. + A lower bound on the momentum of each parameter group + respectively. Defaults to 0. . + eps (float): Minimal decay applied to momentum. If the difference + between new and old momentum is smaller than eps, the update is + ignored. Defaults to 1e-8. + begin (int): Step at which to start triggering the scheduler + to monitor in val within the interval calculated + according to epoch of training. Defaults to 0. + end (int): Step at which to stop triggering the scheduler + to monitor in val within the interval calculated + according to epoch of training. Defaults to INF. + last_step (int): The index of last step. Used for resume without + state dict. Defaults to -1. + by_epoch (bool): Whether the scheduled parameters are updated by + epochs. Defaults to True. + verbose (bool): Whether to print the value for each update. + Defaults to False. + """ + + def step(self, metrics=None): + """Adjusts the momentum of each parameter group based on the specified + schedule. + + Args: + metrics (Dict[str, float], optional): Evaluation results of all + metrics on validation dataset. The keys are the names of the + metrics, and the values are corresponding results. + Defaults to None. + """ + super(MomentumSchedulerMixin, self).step(metrics) + if self.use_betas: + for group in self.optimizer.param_groups: + _, beta_1 = group['betas'] + # update the betas with the calculated value + group['betas'] = (group['momentum'], beta_1) diff --git a/mmengine/optim/scheduler/param_scheduler.py b/mmengine/optim/scheduler/param_scheduler.py index 7fdf5c3d8874ad107703fe2840e09f0fee08fbe7..1bd15e9c4fe61385048da74b7493c6c28a1362f4 100644 --- a/mmengine/optim/scheduler/param_scheduler.py +++ b/mmengine/optim/scheduler/param_scheduler.py @@ -951,7 +951,7 @@ class OneCycleParamScheduler(_ParamScheduler): .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates: https://arxiv.org/abs/1708.07120 - """# noqa E501 + """ # noqa E501 def __init__(self, optimizer: Union[Optimizer, OptimWrapper], @@ -1058,7 +1058,7 @@ class OneCycleParamScheduler(_ParamScheduler): if len(param) != len(optimizer.param_groups): raise ValueError( f'expected {len(optimizer.param_groups)} values ' - f'for {name}, got { len(param)}') + f'for {name}, got {len(param)}') return param else: return [param] * len(optimizer.param_groups) @@ -1283,3 +1283,278 @@ class CosineRestartParamScheduler(_ParamScheduler): if iteration < period: return i return None + + +@PARAM_SCHEDULERS.register_module() +class ReduceOnPlateauParamScheduler(_ParamScheduler): + """Reduce the parameters of each parameter group when a metric has stopped + improving. Models often benefit from reducing the parameters by a factor of + 2-10 once learning stagnates. This scheduler reads a metrics quantity and + if no improvement is seen for a ``patience`` number of epochs, the + parameters are reduced. + + The implementation is motivated by + `PyTorch + ReduceLROnPlateau<https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py>`_. + + Args: + optimizer (Optimizer or OptimWrapper): optimizer or Wrapped + optimizer. + param_name (str): Name of the parameter to be adjusted, such as + ``lr``, ``momentum``. + monitor (str): The name of the metric to measure whether + the performance of the model is improved. + rule (str): One of `less`, `greater`. In `less` rule, parameters will + be reduced when the quantity monitored has stopped + decreasing; in `greater` rule it will be reduced when the + quantity monitored has stopped increasing. Defaults to 'less'. + The ``rule`` is the renaming of ``mode`` in pytorch. + factor (float): Factor by which the parameters will be + reduced. new_param = param * factor. Defaults to 0.1. + patience (int): Number of epochs with no improvement after + which parameters will be reduced. For example, if + ``patience = 2``, then we will ignore the first 2 epochs + with no improvement, and will only decrease the parameters after + the 3rd epoch if the monitor value still hasn't improved then. + Defaults to 10. + threshold (float): Threshold for measuring the new optimum, + to only focus on significant changes. Defaults to 1e-4. + threshold_rule (str): One of `rel`, `abs`. In `rel` rule, + dynamic_threshold = best * ( 1 + threshold ) in 'greater' + rule or best * ( 1 - threshold ) in `less` rule. + In `abs` rule, dynamic_threshold = best + threshold in + `greater` rule or best - threshold in `less` rule. + Defaults to 'rel'. + cooldown (int): Number of epochs to wait before resuming + normal operation after parameters have been reduced. Defaults to 0. + min_value (float or list[float]): A scalar or a sequence of scalars. + A lower bound on the parameters of each parameter group + respectively. Defaults to 0. . + eps (float): Minimal decay applied to parameters. If the difference + between new and old parameters are smaller than eps, the update is + ignored. Defaults to 1e-8. + begin (int): Step at which to start triggering the scheduler + to monitor in val within the interval calculated + according to epoch of training. Defaults to 0. + end (int): Step at which to stop triggering the scheduler + to monitor in val within the interval calculated + according to epoch of training. Defaults to INF. + last_step (int): The index of last step. Used for resume without + state dict. Defaults to -1. + by_epoch (bool): Whether the scheduled parameters are updated by + epochs. Defaults to True. + verbose (bool): Whether to print the value for each update. + Defaults to False. + """ + + need_val_args = True + + def __init__(self, + optimizer: OptimizerType, + param_name: str, + monitor: str = 'loss', + rule: str = 'less', + factor: float = 0.1, + patience: int = 10, + threshold: float = 1e-4, + threshold_rule: str = 'rel', + cooldown: int = 0, + min_value: Union[float, Sequence[float]] = 0., + eps: float = 1e-8, + begin: int = 0, + end: int = INF, + last_step: int = -1, + by_epoch: bool = True, + verbose: bool = False): + + # Attach optimizer + if not isinstance(optimizer, (Optimizer, OptimWrapper)): + raise TypeError('``optimizer`` should be an Optimizer,' + 'but got {}'.format(type(optimizer).__name__)) + self.optimizer = optimizer + self.param_name = param_name + + if end <= begin: + raise ValueError('end should be larger than begin, but got' + ' begin={}, end={}'.format(begin, end)) + self.begin = begin + self.end = end + + assert by_epoch, \ + f'Now {type(self).__name__} only support by_epoch=True' + self.by_epoch = by_epoch + + assert isinstance(last_step, int) and last_step >= -1 + # Initialize valid step count and base values + if last_step == -1: + for group in optimizer.param_groups: + # If the param is never be scheduled, record the current value + # as the initial value. + group.setdefault(f'initial_{param_name}', group[param_name]) + else: + for i, group in enumerate(optimizer.param_groups): + if f'initial_{param_name}' not in group: + raise KeyError( + f"param 'initial_{param_name}' is not specified " + 'in param_groups[{}] when resuming an optimizer'. + format(i)) + + self.last_step = last_step + + self._global_step = 0 + self.verbose = verbose + + if factor >= 1.0: + raise ValueError('Factor should be < 1.0.') + self.factor = factor + + if isinstance(min_value, (list, tuple)): + if len(min_value) != len(optimizer.param_groups): + raise ValueError('expected {} min_lrs, got {}'.format( + len(optimizer.param_groups), len(min_value))) + self.min_values = list(min_value) + else: + self.min_values = [min_value] * len( # type: ignore + optimizer.param_groups) + + self.patience = patience + self.cooldown = cooldown + self.cooldown_counter = 0 + self.rule_worse = None # the worse value for the chosen mode + self.best = None + self.num_bad_epochs = 0 + self.eps = eps + + self.monitor = monitor + self._init_is_better( + rule=rule, threshold=threshold, threshold_rule=threshold_rule) + self._reset() + + # remove call self.step() and init self._global_step = 0 + self._last_value = [ + group[self.param_name] for group in self.optimizer.param_groups + ] + + def step(self, metrics=None): + """Adjusts the parameter value of each parameter group based on the + specified schedule. + + Args: + metrics (Dict[str, float], optional): Evaluation results of all + metrics on validation dataset. The keys are the names of the + metrics, and the values are corresponding results. + Defaults to None. + """ + if metrics is None: + # only to count self._global_step + self._global_step += 1 + return + + if not isinstance(metrics, dict): + raise TypeError('metrics type should be dict,' + f' but got type {type(metrics)}') + + # Compute parameter value per param group in the effective range + if self.begin <= self._global_step < self.end: + self.last_step += 1 + + # convert `metric` to float, in case it's a zero-dim Tensor + metric = metrics.get(self.monitor, None) + if metric is not None: + if self._is_better(metric, self.best): + self.best = metric + self.num_bad_epochs = 0 + else: + self.num_bad_epochs += 1 + + if self._in_cooldown(): + self.cooldown_counter -= 1 + self.num_bad_epochs = 0 # ignore bad epochs in cooldown + + if self.num_bad_epochs > self.patience: + values = self._get_value() + + for i, data in enumerate( + zip(self.optimizer.param_groups, values)): + param_group, value = data + if param_group[self.param_name] - value > self.eps: + param_group[self.param_name] = value + self.print_value(self.verbose, i, value) + self.cooldown_counter = self.cooldown + self.num_bad_epochs = 0 + + else: + raise KeyError(f'Excepted key in {list(metrics.keys())},' + f' but got key {self.monitor} is not in dict') + + self._last_value = [ + group[self.param_name] for group in self.optimizer.param_groups + ] + + def print_value(self, is_verbose: bool, group: int, value: float) -> None: + """Display the current parameter value. + + Args: + is_verbose (bool): Whether to print the value. + group (int): The index of the current ``param_group``. + value (float): The parameter value. + """ + if is_verbose: + step_name = 'epoch' if self.by_epoch else 'iter' + print_log( + f'Adjusting parameter value of group {group} to {value:.4e} ' + f'in {step_name} {self.last_step}.', + logger='current') + + def _get_value(self): + """Compute value using chainable form of the scheduler.""" + values = [ + float(group[self.param_name]) * self.factor + for group in self.optimizer.param_groups + ] + return [max(v, min_v) for v, min_v in zip(values, self.min_values)] + + def _in_cooldown(self): + """Judge whether it is in cooldown.""" + return self.cooldown_counter > 0 + + def _is_better(self, a, best): + """Judge whether the monitor value is better.""" + if self.rule == 'less' and self.threshold_rule == 'rel': + rel_epsilon = 1. - self.threshold + return a < best * rel_epsilon + + elif self.rule == 'less' and self.threshold_rule == 'abs': + return a < best - self.threshold + + elif self.rule == 'greater' and self.threshold_rule == 'rel': + rel_epsilon = self.threshold + 1. + return a > best * rel_epsilon + + else: # rule == 'greater' and epsilon_mode == 'abs': + return a > best + self.threshold + + def _init_is_better(self, rule, threshold, threshold_rule): + """Initialize rule and its associated values.""" + if threshold < 0: + raise ValueError(f'threshold {threshold} should be >= 0.') + if rule not in {'less', 'greater'}: + raise ValueError(f'mode {rule} is unknown!') + if threshold_rule not in {'rel', 'abs'}: + raise ValueError(f'threshold mode {threshold_rule}' + ' is unknown!') + + if rule == 'less': + self.rule_worse = INF + else: # rule == 'greater': + self.rule_worse = -INF + + self.rule = rule + self.threshold = threshold + self.threshold_rule = threshold_rule + + def _reset(self): + """Resets num_bad_epochs counter and cooldown counter.""" + self.best = self.rule_worse + self.cooldown_counter = 0 + self.num_bad_epochs = 0 diff --git a/tests/test_hooks/test_param_scheduler_hook.py b/tests/test_hooks/test_param_scheduler_hook.py index 85c7fb2297c75c5ddb6f6ca14b14f37275cb7230..d4f6b18ee29177d56e66697e7679f5b50303e6d7 100644 --- a/tests/test_hooks/test_param_scheduler_hook.py +++ b/tests/test_hooks/test_param_scheduler_hook.py @@ -4,13 +4,14 @@ from unittest.mock import Mock import pytest from mmengine.hooks import ParamSchedulerHook +from mmengine.optim import _ParamScheduler class TestParamSchedulerHook: error_msg = ('runner.param_schedulers should be list of ParamScheduler or ' 'a dict containing list of ParamScheduler') - def test_after_iter(self): + def test_after_train_iter(self): # runner.param_schedulers should be a list or dict with pytest.raises(TypeError, match=self.error_msg): hook = ParamSchedulerHook() @@ -42,9 +43,10 @@ class TestParamSchedulerHook: runner.param_schedulers = dict(key1=[scheduler1], key2=[scheduler2]) hook.after_train_epoch(runner) hook.after_train_iter(runner, 0) + scheduler1.step.assert_called() scheduler2.step.assert_called() - def test_after_epoch(self): + def test_after_train_epoch(self): # runner.param_schedulers should be a list or dict with pytest.raises(TypeError, match=self.error_msg): hook = ParamSchedulerHook() @@ -53,7 +55,7 @@ class TestParamSchedulerHook: scheduler.step = Mock() scheduler.by_epoch = True runner.param_schedulers = scheduler - hook.after_train_iter(runner, 0) + hook.after_train_epoch(runner) scheduler.step.assert_called() # runner.param_schedulers is a list of schedulers @@ -77,3 +79,51 @@ class TestParamSchedulerHook: hook.after_train_epoch(runner) scheduler1.step.assert_called() scheduler2.step.assert_called() + + def test_after_val_epoch(self): + metrics = dict(loss=1.0) + + # mock super _ParamScheduler class + class MockParamScheduler(_ParamScheduler): + + def __init__(self): + pass + + def _get_value(self): + pass + + # runner.param_schedulers should be a list or dict + with pytest.raises(TypeError, match=self.error_msg): + hook = ParamSchedulerHook() + runner = Mock() + scheduler = Mock() + scheduler.step = Mock() + scheduler.by_epoch = True + scheduler.need_val_args = True + runner.param_schedulers = scheduler + hook.after_val_epoch(runner, metrics) + + # runner.param_schedulers is a list of schedulers + hook = ParamSchedulerHook() + runner = Mock() + scheduler = MockParamScheduler() + scheduler.step = Mock() + scheduler.by_epoch = True + scheduler.need_val_args = True + runner.param_schedulers = [scheduler] + hook.after_val_epoch(runner, metrics) + scheduler.step.assert_called_with(metrics) + + # runner.param_schedulers is a dict containing list of schedulers + scheduler1 = MockParamScheduler() + scheduler1.step = Mock() + scheduler1.by_epoch = True + scheduler1.need_val_args = True + scheduler2 = MockParamScheduler() + scheduler2.step = Mock() + scheduler2.by_epoch = True + scheduler2.need_val_args = True + runner.param_schedulers = dict(key1=[scheduler1], key2=[scheduler2]) + hook.after_val_epoch(runner, metrics) + scheduler1.step.assert_called_with(metrics) + scheduler2.step.assert_called_with(metrics) diff --git a/tests/test_optim/test_scheduler/test_lr_scheduler.py b/tests/test_optim/test_scheduler/test_lr_scheduler.py index bd537b8680e2d0e4a8f72ddfc43f680cd5dbc0e1..c9cd6e1fe6f6209ccc95374e0d8fc049ee93e738 100644 --- a/tests/test_optim/test_scheduler/test_lr_scheduler.py +++ b/tests/test_optim/test_scheduler/test_lr_scheduler.py @@ -8,7 +8,8 @@ import torch.optim as optim from mmengine.optim.scheduler import (ConstantLR, CosineAnnealingLR, CosineRestartLR, ExponentialLR, LinearLR, - MultiStepLR, OneCycleLR, PolyLR, StepLR, + MultiStepLR, OneCycleLR, PolyLR, + ReduceOnPlateauLR, StepLR, _ParamScheduler) from mmengine.testing import assert_allclose @@ -195,9 +196,16 @@ class TestLRScheduler(TestCase): schedulers, targets, epochs=10, - param_name='lr'): + param_name='lr', + step_kwargs=None): if isinstance(schedulers, _ParamScheduler): schedulers = [schedulers] + if step_kwargs is None: + step_kwarg = [{} for _ in range(len(schedulers))] + step_kwargs = [step_kwarg for _ in range(epochs)] + else: # step_kwargs is not None + assert len(step_kwargs) == epochs + assert len(step_kwargs[0]) == len(schedulers) for epoch in range(epochs): for param_group, target in zip(self.optimizer.param_groups, targets): @@ -209,7 +217,10 @@ class TestLRScheduler(TestCase): param_group[param_name]), atol=1e-5, rtol=0) - [scheduler.step() for scheduler in schedulers] + [ + scheduler.step(**step_kwargs[epoch][i]) + for i, scheduler in enumerate(schedulers) + ] def test_step_scheduler(self): # lr = 0.05 if epoch < 3 @@ -361,11 +372,176 @@ class TestLRScheduler(TestCase): eta_min=0) self._test_scheduler_value(scheduler, targets, epochs=10) - def _check_scheduler_state_dict(self, construct, construct2, epochs=10): + def test_reduce_on_plateau_scheduler(self): + # inherit _ParamScheduler but not call super().__init__(), + # so some codes need to be retested + + # Test error in __init__ method + with self.assertRaises(TypeError): + ReduceOnPlateauLR('invalid_optimizer') + with self.assertRaises(ValueError): + ReduceOnPlateauLR(self.optimizer, begin=10, end=5) + with self.assertRaises(AssertionError): + ReduceOnPlateauLR(self.optimizer, by_epoch=False) + + for last_step in (1.5, -2): + with self.assertRaises(AssertionError): + ReduceOnPlateauLR(self.optimizer, last_step=last_step) + + with self.assertRaises(ValueError): + ReduceOnPlateauLR(self.optimizer, factor=2.0) + ReduceOnPlateauLR(self.optimizer, min_value=[0.1, 0.1]) + with self.assertRaises(ValueError): + ReduceOnPlateauLR(self.optimizer, min_value=[0.1, 0.1, 0.1, 0.1]) + with self.assertRaises(ValueError): + ReduceOnPlateauLR(self.optimizer, threshold=-1.0) + with self.assertRaises(ValueError): + ReduceOnPlateauLR(self.optimizer, rule='foo') + with self.assertRaises(ValueError): + ReduceOnPlateauLR(self.optimizer, threshold_rule='foo') + + # Test error in step method + scheduler = ReduceOnPlateauLR(self.optimizer, monitor='loss') + assert scheduler.step() is None + + with self.assertRaises(TypeError): + scheduler.step(('foo', 1.0)) + + metrics = dict(loss_foo=1.0) + with self.assertRaises(KeyError): + scheduler.step(metrics) + + # Test scheduler value + def _test_value(epochs, targets, metrics_list, monitor, rule, factor, + patience, threshold, threshold_rule, cooldown, + min_value): + lr = 0.05 + momentum = 0.01 + weight_decay = 5e-4 + scheduler = ReduceOnPlateauLR( + self.optimizer, + monitor=monitor, + rule=rule, + factor=factor, + patience=patience, + threshold=threshold, + threshold_rule=threshold_rule, + cooldown=cooldown, + min_value=min_value, + ) + self._test_scheduler_value( + scheduler, targets, epochs=epochs, step_kwargs=metrics_list) + + # reset the state of optimizers + self.optimizer = optim.SGD([{ + 'params': self.model.conv1.parameters() + }, { + 'params': self.model.conv2.parameters(), + 'lr': lr * self.layer2_mult, + }], + lr=lr, + momentum=momentum, + weight_decay=weight_decay) + + epochs = 10 + factor = 0.1 + cooldown = 1 + patience = 2 + + # rule(less) and threshold_rule(rel) + rule, threshold_rule = 'less', 'rel' + threshold = 0.01 + monitor = 'loss' + metric_values = [10., 9., 8., 7., 6., 6., 6., 6., 6., 6.] + metrics_list = [[dict(metrics={monitor: v})] for v in metric_values] + single_targets = [ + 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005 + ] + targets = [ + single_targets, [t * self.layer2_mult for t in single_targets] + ] + + _test_value(epochs, targets, metrics_list, monitor, rule, factor, + patience, threshold, threshold_rule, cooldown, 0.0) + + # rule(less) and threshold_rule(abs) + rule, threshold_rule = 'less', 'abs' + threshold = 0.9 + monitor = 'loss' + metric_values = [10., 9., 8., 7., 6., 6., 6., 6., 6., 6.] + metrics_list = [[dict(metrics={monitor: v})] for v in metric_values] + single_targets = [ + 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005 + ] + targets = [ + single_targets, [t * self.layer2_mult for t in single_targets] + ] + + _test_value(epochs, targets, metrics_list, monitor, rule, factor, + patience, threshold, threshold_rule, cooldown, 0.0) + + # rule(greater) and threshold_rule(rel) + rule, threshold_rule = 'greater', 'rel' + threshold = 0.01 + monitor = 'bbox_mAP' + metric_values = [1., 2., 3., 4., 5., 5., 5., 5., 5., 5.] + metrics_list = [[dict(metrics={monitor: v})] for v in metric_values] + single_targets = [ + 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005 + ] + targets = [ + single_targets, [t * self.layer2_mult for t in single_targets] + ] + + _test_value(epochs, targets, metrics_list, monitor, rule, factor, + patience, threshold, threshold_rule, cooldown, 0.0) + + # rule(greater) and threshold_rule(abs) + rule, threshold_rule = 'greater', 'abs' + threshold = 0.9 + monitor = 'bbox_mAP' + metric_values = [1., 2., 3., 4., 5., 5., 5., 5., 5., 5.] + metrics_list = [[dict(metrics={monitor: v})] for v in metric_values] + single_targets = [ + 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005 + ] + targets = [ + single_targets, [t * self.layer2_mult for t in single_targets] + ] + + _test_value(epochs, targets, metrics_list, monitor, rule, factor, + patience, threshold, threshold_rule, cooldown, 0.0) + + # change min_value + min_value = 0.01 + rule, threshold_rule = 'less', 'rel' + threshold = 0.01 + monitor = 'loss' + metric_values = [10., 9., 8., 7., 6., 6., 6., 6., 6., 6.] + metrics_list = [[dict(metrics={monitor: v})] for v in metric_values] + single_targets_1 = [ + 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, min_value, + min_value + ] + single_targets_2 = [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.05, 0.05] + targets = [single_targets_1, single_targets_2] + + _test_value(epochs, targets, metrics_list, monitor, rule, factor, + patience, threshold, threshold_rule, cooldown, min_value) + + def _check_scheduler_state_dict(self, + construct, + construct2, + epochs=10, + step_kwargs=None): + if step_kwargs is None: + step_kwargs = [{} for _ in range(epochs)] + else: # step_kwargs is not None + assert len(step_kwargs) == epochs scheduler = construct() - for _ in range(epochs): + for epoch in range(epochs): scheduler.optimizer.step() - scheduler.step() + scheduler.step(**step_kwargs[epoch]) scheduler_copy = construct2() scheduler_copy.load_state_dict(scheduler.state_dict()) for key in scheduler.__dict__.keys(): @@ -429,6 +605,35 @@ class TestLRScheduler(TestCase): eta_min=0), epochs=10) + def test_reduce_on_plateau_scheduler_state_dict(self): + epochs = 10 + metrics_list = [dict(metrics=dict(loss=1.0)) for _ in range(epochs)] + self._check_scheduler_state_dict( + lambda: ReduceOnPlateauLR( + self.optimizer, + monitor='loss', + rule='less', + factor=0.01, + patience=5, + threshold=1e-4, + threshold_rule='rel', + cooldown=0, + min_value=0.0, + eps=1e-8), + lambda: ReduceOnPlateauLR( + self.optimizer, + monitor='loss_foo', + rule='greater', + factor=0.05, + patience=10, + threshold=1e-5, + threshold_rule='abs', + cooldown=5, + min_value=0.1, + eps=1e-9), + epochs=epochs, + step_kwargs=metrics_list) + def test_step_scheduler_convert_iterbased(self): # invalid epoch_length with self.assertRaises(AssertionError): diff --git a/tests/test_optim/test_scheduler/test_momentum_scheduler.py b/tests/test_optim/test_scheduler/test_momentum_scheduler.py index d98f2fe424a2f915eb9c200b6b0d423e67bab15c..942259d7dad4d7a1ef56ba90f7bd9fdeb6e85c13 100644 --- a/tests/test_optim/test_scheduler/test_momentum_scheduler.py +++ b/tests/test_optim/test_scheduler/test_momentum_scheduler.py @@ -6,12 +6,15 @@ import torch import torch.nn.functional as F import torch.optim as optim +# yapf: disable from mmengine.optim.scheduler import (ConstantMomentum, CosineAnnealingMomentum, CosineRestartMomentum, ExponentialMomentum, LinearMomentum, MultiStepMomentum, PolyMomentum, - StepMomentum, _ParamScheduler) + ReduceOnPlateauMomentum, StepMomentum, + _ParamScheduler) +# yapf: enable from mmengine.testing import assert_allclose @@ -213,9 +216,16 @@ class TestMomentumScheduler(TestCase): schedulers, targets, epochs=10, - param_name='momentum'): + param_name='momentum', + step_kwargs=None): if isinstance(schedulers, _ParamScheduler): schedulers = [schedulers] + if step_kwargs is None: + step_kwarg = [{} for _ in range(len(schedulers))] + step_kwargs = [step_kwarg for _ in range(epochs)] + else: # step_kwargs is not None + assert len(step_kwargs) == epochs + assert len(step_kwargs[0]) == len(schedulers) for epoch in range(epochs): for param_group, target in zip(optimizer.param_groups, targets): assert_allclose( @@ -235,7 +245,10 @@ class TestMomentumScheduler(TestCase): param_group['betas'][0]), atol=1e-5, rtol=0) - [scheduler.step() for scheduler in schedulers] + [ + scheduler.step(**step_kwargs[epoch][i]) + for i, scheduler in enumerate(schedulers) + ] def test_step_scheduler(self): # momentum = 0.05 if epoch < 3 @@ -437,11 +450,204 @@ class TestMomentumScheduler(TestCase): self._test_scheduler_value( self.optimizer_with_betas, scheduler, targets, epochs=10) - def _check_scheduler_state_dict(self, construct, construct2, epochs=10): + def test_reduce_on_plateau_scheduler(self): + # inherit _ParamScheduler but not call super().__init__(), + # so some codes need to be retested + + # Test error in __init__ method + with self.assertRaises(ValueError): + optimizer = optim.ASGD( + self.model.parameters(), + lr=0.01, + ) + ReduceOnPlateauMomentum(optimizer) + with self.assertRaises(ValueError): + ReduceOnPlateauMomentum(self.optimizer, begin=10, end=5) + with self.assertRaises(AssertionError): + ReduceOnPlateauMomentum(self.optimizer, by_epoch=False) + + for last_step in (1.5, -2): + with self.assertRaises(AssertionError): + ReduceOnPlateauMomentum(self.optimizer, last_step=last_step) + + with self.assertRaises(ValueError): + ReduceOnPlateauMomentum(self.optimizer, factor=2.0) + ReduceOnPlateauMomentum(self.optimizer, min_value=[0.1, 0.1]) + with self.assertRaises(ValueError): + ReduceOnPlateauMomentum( + self.optimizer, min_value=[0.1, 0.1, 0.1, 0.1]) + with self.assertRaises(ValueError): + ReduceOnPlateauMomentum(self.optimizer, threshold=-1.0) + with self.assertRaises(ValueError): + ReduceOnPlateauMomentum(self.optimizer, rule='foo') + with self.assertRaises(ValueError): + ReduceOnPlateauMomentum(self.optimizer, threshold_rule='foo') + + # Test error in step method + scheduler = ReduceOnPlateauMomentum(self.optimizer, monitor='loss') + assert scheduler.step() is None + + with self.assertRaises(TypeError): + scheduler.step(('foo', 1.0)) + + metrics = dict(loss_foo=1.0) + with self.assertRaises(KeyError): + scheduler.step(metrics) + + # Test scheduler value + def _test_value(epochs, targets, metrics_list, optimizer, monitor, + rule, factor, patience, threshold, threshold_rule, + cooldown, min_value): + lr = 0.01 + momentum = 0.05 + weight_decay = 5e-4 + scheduler = ReduceOnPlateauMomentum( + optimizer, + monitor=monitor, + rule=rule, + factor=factor, + patience=patience, + threshold=threshold, + threshold_rule=threshold_rule, + cooldown=cooldown, + min_value=min_value, + ) + self._test_scheduler_value( + optimizer, + scheduler, + targets, + epochs=epochs, + step_kwargs=metrics_list) + + # reset the state of optimizers + self.optimizer = optim.SGD([{ + 'params': self.model.conv1.parameters() + }, { + 'params': self.model.conv2.parameters(), + 'momentum': momentum * self.layer2_mult + }], + lr=lr, + momentum=momentum, + weight_decay=weight_decay) + self.optimizer_with_betas = optim.Adam( + [{ + 'params': self.model.conv1.parameters() + }, { + 'params': self.model.conv2.parameters(), + 'betas': (momentum * self.layer2_mult, 0.999) + }], + lr=lr, + betas=(momentum, 0.999), + weight_decay=weight_decay) + + epochs = 10 + factor = 0.1 + cooldown = 1 + patience = 2 + + # rule(less) and threshold_rule(rel) + rule, threshold_rule = 'less', 'rel' + threshold = 0.01 + monitor = 'loss' + metric_values = [10., 9., 8., 7., 6., 6., 6., 6., 6., 6.] + metrics_list = [[dict(metrics={monitor: v})] for v in metric_values] + single_targets = [ + 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005 + ] + targets = [ + single_targets, [t * self.layer2_mult for t in single_targets] + ] + + _test_value(epochs, targets, metrics_list, self.optimizer, monitor, + rule, factor, patience, threshold, threshold_rule, + cooldown, 0.0) + + # rule(less) and threshold_rule(abs) + rule, threshold_rule = 'less', 'abs' + threshold = 0.9 + monitor = 'loss' + metric_values = [10., 9., 8., 7., 6., 6., 6., 6., 6., 6.] + metrics_list = [[dict(metrics={monitor: v})] for v in metric_values] + single_targets = [ + 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005 + ] + targets = [ + single_targets, [t * self.layer2_mult for t in single_targets] + ] + + _test_value(epochs, targets, metrics_list, self.optimizer, monitor, + rule, factor, patience, threshold, threshold_rule, + cooldown, 0.0) + + # rule(greater) and threshold_rule(rel) + rule, threshold_rule = 'greater', 'rel' + threshold = 0.01 + monitor = 'bbox_mAP' + metric_values = [1., 2., 3., 4., 5., 5., 5., 5., 5., 5.] + metrics_list = [[dict(metrics={monitor: v})] for v in metric_values] + single_targets = [ + 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005 + ] + targets = [ + single_targets, [t * self.layer2_mult for t in single_targets] + ] + + _test_value(epochs, targets, metrics_list, self.optimizer, monitor, + rule, factor, patience, threshold, threshold_rule, + cooldown, 0.0) + + # rule(greater) and threshold_rule(abs) + rule, threshold_rule = 'greater', 'abs' + threshold = 0.9 + monitor = 'bbox_mAP' + metric_values = [1., 2., 3., 4., 5., 5., 5., 5., 5., 5.] + metrics_list = [[dict(metrics={monitor: v})] for v in metric_values] + single_targets = [ + 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005 + ] + targets = [ + single_targets, [t * self.layer2_mult for t in single_targets] + ] + + _test_value(epochs, targets, metrics_list, self.optimizer, monitor, + rule, factor, patience, threshold, threshold_rule, + cooldown, 0.0) + + # change min_value + min_value = 0.01 + rule, threshold_rule = 'less', 'rel' + threshold = 0.01 + monitor = 'loss' + metric_values = [10., 9., 8., 7., 6., 6., 6., 6., 6., 6.] + metrics_list = [[dict(metrics={monitor: v})] for v in metric_values] + single_targets_1 = [ + 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, min_value, + min_value + ] + single_targets_2 = [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.05, 0.05] + targets = [single_targets_1, single_targets_2] + + _test_value(epochs, targets, metrics_list, self.optimizer, monitor, + rule, factor, patience, threshold, threshold_rule, + cooldown, min_value) + + _test_value(epochs, targets, metrics_list, self.optimizer_with_betas, + monitor, rule, factor, patience, threshold, threshold_rule, + cooldown, min_value) + + def _check_scheduler_state_dict(self, + construct, + construct2, + epochs=10, + step_kwargs=None): + if step_kwargs is None: + step_kwargs = [{} for _ in range(epochs)] + else: # step_kwargs is not None + assert len(step_kwargs) == epochs scheduler = construct() - for _ in range(epochs): + for epoch in range(epochs): scheduler.optimizer.step() - scheduler.step() + scheduler.step(**step_kwargs[epoch]) scheduler_copy = construct2() scheduler_copy.load_state_dict(scheduler.state_dict()) for key in scheduler.__dict__.keys(): @@ -506,6 +712,35 @@ class TestMomentumScheduler(TestCase): eta_min=0), epochs=10) + def test_reduce_on_plateau_scheduler_state_dict(self): + epochs = 10 + metrics_list = [dict(metrics=dict(loss=1.0)) for _ in range(epochs)] + self._check_scheduler_state_dict( + lambda: ReduceOnPlateauMomentum( + self.optimizer, + monitor='loss', + rule='less', + factor=0.01, + patience=5, + threshold=1e-4, + threshold_rule='rel', + cooldown=0, + min_value=0.0, + eps=1e-8), + lambda: ReduceOnPlateauMomentum( + self.optimizer, + monitor='loss_foo', + rule='greater', + factor=0.05, + patience=10, + threshold=1e-5, + threshold_rule='abs', + cooldown=5, + min_value=0.1, + eps=1e-9), + epochs=epochs, + step_kwargs=metrics_list) + def test_multi_scheduler_without_overlap_linear_multi_step(self): # use Linear in the first 5 epochs and then use MultiStep epochs = 12 diff --git a/tests/test_optim/test_scheduler/test_param_scheduler.py b/tests/test_optim/test_scheduler/test_param_scheduler.py index ce86195b944fc9c696bdb30c596d718915ee5180..b210983143e5086f8074aa135f63d8c4b061b6ab 100644 --- a/tests/test_optim/test_scheduler/test_param_scheduler.py +++ b/tests/test_optim/test_scheduler/test_param_scheduler.py @@ -17,8 +17,9 @@ from mmengine.optim.scheduler import (ConstantParamScheduler, LinearParamScheduler, MultiStepParamScheduler, OneCycleParamScheduler, - PolyParamScheduler, StepParamScheduler, - _ParamScheduler) + PolyParamScheduler, + ReduceOnPlateauParamScheduler, + StepParamScheduler, _ParamScheduler) # yapf: enable from mmengine.testing import assert_allclose @@ -239,9 +240,16 @@ class TestParameterScheduler(TestCase): schedulers, targets, epochs=10, - param_name='lr'): + param_name='lr', + step_kwargs=None): if isinstance(schedulers, _ParamScheduler): schedulers = [schedulers] + if step_kwargs is None: + step_kwarg = [{} for _ in range(len(schedulers))] + step_kwargs = [step_kwarg for _ in range(epochs)] + else: # step_kwargs is not None + assert len(step_kwargs) == epochs + assert len(step_kwargs[0]) == len(schedulers) for epoch in range(epochs): for param_group, target in zip(self.optimizer.param_groups, targets): @@ -253,7 +261,10 @@ class TestParameterScheduler(TestCase): param_group[param_name]), atol=1e-5, rtol=0) - [scheduler.step() for scheduler in schedulers] + [ + scheduler.step(**step_kwargs[epoch][i]) + for i, scheduler in enumerate(schedulers) + ] def test_step_scheduler(self): # lr = 0.05 if epoch < 3 @@ -488,11 +499,186 @@ class TestParameterScheduler(TestCase): eta_min=eta_min) self._test_scheduler_value(scheduler, targets, epochs=10) - def _check_scheduler_state_dict(self, construct, construct2, epochs=10): + def test_reduce_on_plateau_scheduler(self): + # inherit _ParamScheduler but not call super().__init__(), + # so some codes need to be retested + + # Test error in __init__ method + with self.assertRaises(TypeError): + ReduceOnPlateauParamScheduler('invalid_optimizer', param_name='lr') + with self.assertRaises(ValueError): + ReduceOnPlateauParamScheduler( + self.optimizer, 'lr', begin=10, end=5) + with self.assertRaises(AssertionError): + ReduceOnPlateauParamScheduler(self.optimizer, 'lr', by_epoch=False) + + for last_step in (1.5, -2): + with self.assertRaises(AssertionError): + ReduceOnPlateauParamScheduler( + self.optimizer, 'lr', last_step=last_step) + + with self.assertRaises(ValueError): + ReduceOnPlateauParamScheduler(self.optimizer, 'lr', factor=2.0) + ReduceOnPlateauParamScheduler( + self.optimizer, 'lr', min_value=[0.1, 0.1]) + with self.assertRaises(ValueError): + ReduceOnPlateauParamScheduler( + self.optimizer, 'lr', min_value=[0.1, 0.1, 0.1, 0.1]) + with self.assertRaises(ValueError): + ReduceOnPlateauParamScheduler(self.optimizer, 'lr', threshold=-1.0) + with self.assertRaises(ValueError): + ReduceOnPlateauParamScheduler(self.optimizer, 'lr', rule='foo') + with self.assertRaises(ValueError): + ReduceOnPlateauParamScheduler( + self.optimizer, 'lr', threshold_rule='foo') + + # Test error in step method + scheduler = ReduceOnPlateauParamScheduler( + self.optimizer, param_name='lr', monitor='loss') + assert scheduler.step() is None + + with self.assertRaises(TypeError): + scheduler.step(('foo', 1.0)) + + metrics = dict(loss_foo=1.0) + with self.assertRaises(KeyError): + scheduler.step(metrics) + + # Test scheduler value + def _test_value(epochs, targets, metrics_list, monitor, rule, factor, + patience, threshold, threshold_rule, cooldown, + min_value): + lr = 0.05 + momentum = 0.01 + weight_decay = 5e-4 + scheduler = ReduceOnPlateauParamScheduler( + self.optimizer, + param_name='lr', + monitor=monitor, + rule=rule, + factor=factor, + patience=patience, + threshold=threshold, + threshold_rule=threshold_rule, + cooldown=cooldown, + min_value=min_value, + ) + self._test_scheduler_value( + scheduler, targets, epochs=epochs, step_kwargs=metrics_list) + + # reset the state of optimizers + self.optimizer = optim.SGD( + [{ + 'params': self.model.conv1.parameters() + }, { + 'params': self.model.conv2.parameters(), + 'lr': lr * self.layer2_mult, + 'momentum': momentum * self.layer2_mult, + 'weight_decay': weight_decay * self.layer2_mult + }], + lr=lr, + momentum=momentum, + weight_decay=weight_decay) + + epochs = 10 + factor = 0.1 + cooldown = 1 + patience = 2 + + # rule(less) and threshold_rule(rel) + rule, threshold_rule = 'less', 'rel' + threshold = 0.01 + monitor = 'loss' + metric_values = [10., 9., 8., 7., 6., 6., 6., 6., 6., 6.] + metrics_list = [[dict(metrics={monitor: v})] for v in metric_values] + single_targets = [ + 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005 + ] + targets = [ + single_targets, [t * self.layer2_mult for t in single_targets] + ] + + _test_value(epochs, targets, metrics_list, monitor, rule, factor, + patience, threshold, threshold_rule, cooldown, 0.0) + + # rule(less) and threshold_rule(abs) + rule, threshold_rule = 'less', 'abs' + threshold = 0.9 + monitor = 'loss' + metric_values = [10., 9., 8., 7., 6., 6., 6., 6., 6., 6.] + metrics_list = [[dict(metrics={monitor: v})] for v in metric_values] + single_targets = [ + 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005 + ] + targets = [ + single_targets, [t * self.layer2_mult for t in single_targets] + ] + + _test_value(epochs, targets, metrics_list, monitor, rule, factor, + patience, threshold, threshold_rule, cooldown, 0.0) + + # rule(greater) and threshold_rule(rel) + rule, threshold_rule = 'greater', 'rel' + threshold = 0.01 + monitor = 'bbox_mAP' + metric_values = [1., 2., 3., 4., 5., 5., 5., 5., 5., 5.] + metrics_list = [[dict(metrics={monitor: v})] for v in metric_values] + single_targets = [ + 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005 + ] + targets = [ + single_targets, [t * self.layer2_mult for t in single_targets] + ] + + _test_value(epochs, targets, metrics_list, monitor, rule, factor, + patience, threshold, threshold_rule, cooldown, 0.0) + + # rule(greater) and threshold_rule(abs) + rule, threshold_rule = 'greater', 'abs' + threshold = 0.9 + monitor = 'bbox_mAP' + metric_values = [1., 2., 3., 4., 5., 5., 5., 5., 5., 5.] + metrics_list = [[dict(metrics={monitor: v})] for v in metric_values] + single_targets = [ + 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005 + ] + targets = [ + single_targets, [t * self.layer2_mult for t in single_targets] + ] + + _test_value(epochs, targets, metrics_list, monitor, rule, factor, + patience, threshold, threshold_rule, cooldown, 0.0) + + # change min_value + min_value = 0.01 + rule, threshold_rule = 'less', 'rel' + threshold = 0.01 + monitor = 'loss' + metric_values = [10., 9., 8., 7., 6., 6., 6., 6., 6., 6.] + metrics_list = [[dict(metrics={monitor: v})] for v in metric_values] + single_targets_1 = [ + 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, min_value, + min_value + ] + single_targets_2 = [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.05, 0.05] + targets = [single_targets_1, single_targets_2] + + _test_value(epochs, targets, metrics_list, monitor, rule, factor, + patience, threshold, threshold_rule, cooldown, min_value) + + def _check_scheduler_state_dict(self, + construct, + construct2, + epochs=10, + step_kwargs=None): + if step_kwargs is None: + step_kwargs = [{} for _ in range(epochs)] + else: # step_kwargs is not None + assert len(step_kwargs) == epochs scheduler = construct() - for _ in range(epochs): + for epoch in range(epochs): scheduler.optimizer.step() - scheduler.step() + scheduler.step(**step_kwargs[epoch]) scheduler_copy = construct2() torch.save(scheduler.state_dict(), osp.join(self.temp_dir.name, 'tmp.pth')) @@ -581,6 +767,37 @@ class TestParameterScheduler(TestCase): eta_min=0), epochs=10) + def test_reduce_on_plateau_scheduler_state_dict(self): + epochs = 10 + metrics_list = [dict(metrics=dict(loss=1.0)) for _ in range(epochs)] + self._check_scheduler_state_dict( + lambda: ReduceOnPlateauParamScheduler( + self.optimizer, + param_name='lr', + monitor='loss', + rule='less', + factor=0.01, + patience=5, + threshold=1e-4, + threshold_rule='rel', + cooldown=0, + min_value=0.0, + eps=1e-8), + lambda: ReduceOnPlateauParamScheduler( + self.optimizer, + param_name='lr', + monitor='loss_foo', + rule='greater', + factor=0.05, + patience=10, + threshold=1e-5, + threshold_rule='abs', + cooldown=5, + min_value=0.1, + eps=1e-9), + epochs=epochs, + step_kwargs=metrics_list) + def test_step_scheduler_convert_iterbased(self): # invalid epoch_length with self.assertRaises(AssertionError):