diff --git a/docs/en/api/optim.rst b/docs/en/api/optim.rst
index 634884c9f522d437fbe8e50c2515d9218848c8b7..3de2ade6a2e3b1c7aa7c74a6629991c79c052f92 100644
--- a/docs/en/api/optim.rst
+++ b/docs/en/api/optim.rst
@@ -63,3 +63,6 @@ Scheduler
    StepLR
    StepMomentum
    StepParamScheduler
+   ReduceOnPlateauLR
+   ReduceOnPlateauMomentum
+   ReduceOnPlateauParamScheduler
diff --git a/docs/zh_cn/api/optim.rst b/docs/zh_cn/api/optim.rst
index 634884c9f522d437fbe8e50c2515d9218848c8b7..3de2ade6a2e3b1c7aa7c74a6629991c79c052f92 100644
--- a/docs/zh_cn/api/optim.rst
+++ b/docs/zh_cn/api/optim.rst
@@ -63,3 +63,6 @@ Scheduler
    StepLR
    StepMomentum
    StepParamScheduler
+   ReduceOnPlateauLR
+   ReduceOnPlateauMomentum
+   ReduceOnPlateauParamScheduler
diff --git a/mmengine/hooks/param_scheduler_hook.py b/mmengine/hooks/param_scheduler_hook.py
index cb033ce4a56f524d2eab41fbaeff077968fd7891..3b2f1e610a9db058e99d03faf607514d73bce030 100644
--- a/mmengine/hooks/param_scheduler_hook.py
+++ b/mmengine/hooks/param_scheduler_hook.py
@@ -1,7 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Optional, Union
+from typing import Dict, Optional, Union
 
+from mmengine.optim import _ParamScheduler
 from mmengine.registry import HOOKS
+from mmengine.utils import is_list_of
 from .hook import Hook
 
 DATA_BATCH = Optional[Union[dict, tuple, list]]
@@ -19,7 +21,7 @@ class ParamSchedulerHook(Hook):
                          batch_idx: int,
                          data_batch: DATA_BATCH = None,
                          outputs: Optional[dict] = None) -> None:
-        """Call step function for each scheduler after each iteration.
+        """Call step function for each scheduler after each training iteration.
 
         Args:
             runner (Runner): The runner of the training process.
@@ -32,15 +34,15 @@ class ParamSchedulerHook(Hook):
                 keep ``data_batch`` here.
         """
 
+        if runner.param_schedulers is None:
+            return
+
         def step(param_schedulers):
             assert isinstance(param_schedulers, list)
             for scheduler in param_schedulers:
                 if not scheduler.by_epoch:
                     scheduler.step()
 
-        if runner.param_schedulers is None:
-            return
-
         if isinstance(runner.param_schedulers, list):
             step(runner.param_schedulers)
         elif isinstance(runner.param_schedulers, dict):
@@ -53,21 +55,67 @@ class ParamSchedulerHook(Hook):
                 f'but got {runner.param_schedulers}')
 
     def after_train_epoch(self, runner) -> None:
-        """Call step function for each scheduler after each epoch.
+        """Call step function for each scheduler after each training epoch.
 
         Args:
             runner (Runner): The runner of the training process.
         """
 
+        if runner.param_schedulers is None:
+            return
+
         def step(param_schedulers):
             assert isinstance(param_schedulers, list)
             for scheduler in param_schedulers:
                 if scheduler.by_epoch:
                     scheduler.step()
 
+        if isinstance(runner.param_schedulers, list):
+            step(runner.param_schedulers)
+        elif isinstance(runner.param_schedulers, dict):
+            for param_schedulers in runner.param_schedulers.values():
+                step(param_schedulers)
+        else:
+            raise TypeError(
+                'runner.param_schedulers should be list of ParamScheduler or '
+                'a dict containing list of ParamScheduler, '
+                f'but got {runner.param_schedulers}')
+
+    def after_val_epoch(self,
+                        runner,
+                        metrics: Optional[Dict[str, float]] = None) -> None:
+        """Call step function for each scheduler which has attribute
+        ``need_val_args`` after each validation epoch.
+
+        Args:
+            runner (Runner): The runner of the validation process.
+            metrics (Dict[str, float], optional): Evaluation results of all
+                metrics on validation dataset. The keys are the names of the
+                metrics, and the values are corresponding results.
+
+        Note:
+            If ``runner.param_schedulers`` has not been built,
+            ``after_val_epoch`` will be skipped.
+        """
+
         if runner.param_schedulers is None:
             return
 
+        # avoid double-counting scheduler._global_step: it has already
+        # been counted in the after_train_* hooks
+        if metrics is None:
+            return
+
+        def step(param_schedulers):
+            # skip if param_schedulers is not a list of built schedulers
+            if not is_list_of(param_schedulers, _ParamScheduler):
+                return
+
+            for scheduler in param_schedulers:
+                if (scheduler.by_epoch
+                        and getattr(scheduler, 'need_val_args', False)):
+                    scheduler.step(metrics)
+
         if isinstance(runner.param_schedulers, list):
             step(runner.param_schedulers)
         elif isinstance(runner.param_schedulers, dict):
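
For reference, a hypothetical config sketch (not part of this patch) showing how this hook path gets exercised: the scheduler is registered in `PARAM_SCHEDULERS`, so it can be referenced by type name, and `ParamSchedulerHook` will call `scheduler.step(metrics)` in `after_val_epoch` because the scheduler defines `need_val_args = True`. The metric key is an assumption and must match a key produced by the evaluator.

```python
# Hypothetical config sketch; 'loss' is an assumed metric key.
param_scheduler = dict(
    type='ReduceOnPlateauLR',  # resolved via the PARAM_SCHEDULERS registry
    monitor='loss',
    rule='less',
    factor=0.1,
    patience=10)
```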
diff --git a/mmengine/optim/__init__.py b/mmengine/optim/__init__.py
index 9b34441ef45b4de0a618634cc2d5cc73f1b038f2..72118b179f3f07e2ee4fa0f1cceb1e1a3c394332 100644
--- a/mmengine/optim/__init__.py
+++ b/mmengine/optim/__init__.py
@@ -11,8 +11,10 @@ from .scheduler import (ConstantLR, ConstantMomentum, ConstantParamScheduler,
                         MultiStepLR, MultiStepMomentum,
                         MultiStepParamScheduler, OneCycleLR,
                         OneCycleParamScheduler, PolyLR, PolyMomentum,
-                        PolyParamScheduler, StepLR, StepMomentum,
-                        StepParamScheduler, _ParamScheduler)
+                        PolyParamScheduler, ReduceOnPlateauLR,
+                        ReduceOnPlateauMomentum, ReduceOnPlateauParamScheduler,
+                        StepLR, StepMomentum, StepParamScheduler,
+                        _ParamScheduler)
 
 # yapf: enable
 __all__ = [
@@ -25,5 +27,6 @@ __all__ = [
     'LinearParamScheduler', 'MultiStepParamScheduler', 'StepParamScheduler',
     '_ParamScheduler', 'OptimWrapper', 'AmpOptimWrapper', 'OptimWrapperDict',
     'OneCycleParamScheduler', 'OneCycleLR', 'PolyLR', 'PolyMomentum',
-    'PolyParamScheduler'
+    'PolyParamScheduler', 'ReduceOnPlateauLR', 'ReduceOnPlateauMomentum',
+    'ReduceOnPlateauParamScheduler'
 ]
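
Since the new classes are re-exported from the package root, the shorter import path below works after this change (a quick sanity check, nothing assumed beyond this diff):

```python
from mmengine.optim import (ReduceOnPlateauLR, ReduceOnPlateauMomentum,
                            ReduceOnPlateauParamScheduler)
```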
diff --git a/mmengine/optim/scheduler/__init__.py b/mmengine/optim/scheduler/__init__.py
index c12d1c2970d39561e3ddec2e92e49f67bbfdc4ba..48ccc34bc41b07442e2494b03a303b3c0054b42b 100644
--- a/mmengine/optim/scheduler/__init__.py
+++ b/mmengine/optim/scheduler/__init__.py
@@ -2,21 +2,22 @@
 # yapf: disable
 from .lr_scheduler import (ConstantLR, CosineAnnealingLR, CosineRestartLR,
                            ExponentialLR, LinearLR, MultiStepLR, OneCycleLR,
-                           PolyLR, StepLR)
+                           PolyLR, ReduceOnPlateauLR, StepLR)
 from .momentum_scheduler import (ConstantMomentum, CosineAnnealingMomentum,
                                  CosineRestartMomentum, ExponentialMomentum,
                                  LinearMomentum, MultiStepMomentum,
-                                 PolyMomentum, StepMomentum)
+                                 PolyMomentum, ReduceOnPlateauMomentum,
+                                 StepMomentum)
 from .param_scheduler import (ConstantParamScheduler,
                               CosineAnnealingParamScheduler,
                               CosineRestartParamScheduler,
                               ExponentialParamScheduler, LinearParamScheduler,
                               MultiStepParamScheduler, OneCycleParamScheduler,
-                              PolyParamScheduler, StepParamScheduler,
-                              _ParamScheduler)
+                              PolyParamScheduler,
+                              ReduceOnPlateauParamScheduler,
+                              StepParamScheduler, _ParamScheduler)
 
 # yapf: enable
-
 __all__ = [
     'ConstantLR', 'CosineAnnealingLR', 'ExponentialLR', 'LinearLR',
     'MultiStepLR', 'StepLR', 'ConstantMomentum', 'CosineAnnealingMomentum',
@@ -26,5 +27,6 @@ __all__ = [
     'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler',
     'PolyParamScheduler', 'PolyLR', 'PolyMomentum', 'OneCycleParamScheduler',
     'OneCycleLR', 'CosineRestartParamScheduler', 'CosineRestartLR',
-    'CosineRestartMomentum'
+    'CosineRestartMomentum', 'ReduceOnPlateauParamScheduler',
+    'ReduceOnPlateauLR', 'ReduceOnPlateauMomentum'
 ]
diff --git a/mmengine/optim/scheduler/lr_scheduler.py b/mmengine/optim/scheduler/lr_scheduler.py
index b1eeabec5f794e8130158ac710907f33827c8f77..08b98ee76fb444c0ed648dd07f5a8ce3cb7837b3 100644
--- a/mmengine/optim/scheduler/lr_scheduler.py
+++ b/mmengine/optim/scheduler/lr_scheduler.py
@@ -1,11 +1,16 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from mmengine.registry import PARAM_SCHEDULERS
+# yapf: disable
 from .param_scheduler import (ConstantParamScheduler,
                               CosineAnnealingParamScheduler,
                               CosineRestartParamScheduler,
                               ExponentialParamScheduler, LinearParamScheduler,
                               MultiStepParamScheduler, OneCycleParamScheduler,
-                              PolyParamScheduler, StepParamScheduler)
+                              PolyParamScheduler,
+                              ReduceOnPlateauParamScheduler,
+                              StepParamScheduler)
+
+# yapf: enable
 
 
 class LRSchedulerMixin:
@@ -314,3 +319,60 @@ class CosineRestartLR(LRSchedulerMixin, CosineRestartParamScheduler):
         verbose (bool): Whether to print the value for each update.
             Defaults to False.
     """
+
+
+@PARAM_SCHEDULERS.register_module()
+class ReduceOnPlateauLR(LRSchedulerMixin, ReduceOnPlateauParamScheduler):
+    """Reduce the learning rate of each parameter group when a metric has
+    stopped improving. Models often benefit from reducing the learning rate by
+    a factor of 2-10 once learning stagnates. This scheduler reads a metrics
+    quantity and if no improvement is seen for a ``patience`` number of epochs,
+    the learning rate is reduced.
+
+    Args:
+        optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
+            optimizer.
+        monitor (str): Key name of the value to monitor in metrics dict.
+        rule (str): One of `less`, `greater`. In `less` rule, learning rate
+            will be reduced when the quantity monitored has stopped
+            decreasing; in `greater` rule it will be reduced when the
+            quantity monitored has stopped increasing. Defaults to 'less'.
+            ``rule`` corresponds to ``mode`` in PyTorch's ReduceLROnPlateau.
+        factor (float): Factor by which the learning rate will be
+            reduced. new_param = param * factor. Defaults to 0.1.
+        patience (int): Number of epochs with no improvement after
+            which learning rate will be reduced. For example, if
+            ``patience = 2``, then we will ignore the first 2 epochs
+            with no improvement, and will only decrease the learning rate after
+            the 3rd epoch if the monitor value still hasn't improved then.
+            Defaults to 10.
+        threshold (float): Threshold for measuring the new optimum,
+            to only focus on significant changes. Defaults to 1e-4.
+        threshold_rule (str): One of `rel`, `abs`. In `rel` rule,
+            dynamic_threshold = best * (1 + threshold) in `greater`
+            rule or best * (1 - threshold) in `less` rule.
+            In `abs` rule, dynamic_threshold = best + threshold in
+            `greater` rule or best - threshold in `less` rule.
+            Defaults to 'rel'.
+        cooldown (int): Number of epochs to wait before resuming
+            normal operation after learning rate has been reduced.
+            Defaults to 0.
+        min_value (float or list[float]): A scalar or a sequence of scalars
+            giving a lower bound on the learning rate of each parameter
+            group respectively. Defaults to 0.
+        eps (float): Minimal decay applied to learning rate. If the difference
+            between new and old learning rate is smaller than eps, the update
+            is ignored. Defaults to 1e-8.
+        begin (int): Epoch at which the scheduler starts to monitor the
+            validation metrics. The interval is counted in training
+            epochs. Defaults to 0.
+        end (int): Epoch at which the scheduler stops monitoring the
+            validation metrics. The interval is counted in training
+            epochs. Defaults to INF.
+        last_step (int): The index of last step. Used for resume without
+            state dict. Defaults to -1.
+        by_epoch (bool): Whether the scheduled parameters are updated by
+            epochs. Defaults to True.
+        verbose (bool): Whether to print the value for each update.
+            Defaults to False.
+    """
diff --git a/mmengine/optim/scheduler/momentum_scheduler.py b/mmengine/optim/scheduler/momentum_scheduler.py
index 102b173146f525a84ad5892d5e544b6ea44c2248..b15c6c914f100e89ee158df650f938cef5df57d3 100644
--- a/mmengine/optim/scheduler/momentum_scheduler.py
+++ b/mmengine/optim/scheduler/momentum_scheduler.py
@@ -1,12 +1,16 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from mmengine.registry import PARAM_SCHEDULERS
+# yapf: disable
 from .param_scheduler import (ConstantParamScheduler,
                               CosineAnnealingParamScheduler,
                               CosineRestartParamScheduler,
                               ExponentialParamScheduler, LinearParamScheduler,
                               MultiStepParamScheduler, PolyParamScheduler,
+                              ReduceOnPlateauParamScheduler,
                               StepParamScheduler)
 
+# yapf: enable
+
 
 class MomentumSchedulerMixin:
     """A mixin class for momentum schedulers.
@@ -32,8 +36,8 @@ class MomentumSchedulerMixin:
         super().__init__(optimizer, param_name, *args, **kwargs)
 
     def step(self):
-        """Adjusts the parameter value of each parameter group based on the
-        specified schedule."""
+        """Adjusts the momentum of each parameter group based on the specified
+        schedule."""
         super().step()
         if self.use_betas:
             for group in self.optimizer.param_groups:
@@ -281,3 +285,77 @@ class CosineRestartMomentum(MomentumSchedulerMixin,
         verbose (bool): Whether to print the value for each update.
             Defaults to False.
     """
+
+
+@PARAM_SCHEDULERS.register_module()
+class ReduceOnPlateauMomentum(MomentumSchedulerMixin,
+                              ReduceOnPlateauParamScheduler):
+    """Reduce the momentum of each parameter group when a metric has stopped
+    improving. Models often benefit from reducing the momentum by a factor of
+    2-10 once learning stagnates. This scheduler reads a metrics quantity and
+    if no improvement is seen for a ``patience`` number of epochs, the momentum
+    is reduced.
+
+    Args:
+        optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
+            optimizer.
+        monitor (str): Key name of the value to monitor in metrics dict.
+        rule (str): One of `less`, `greater`. In `less` rule, momentum will
+            be reduced when the quantity monitored has stopped
+            decreasing; in `greater` rule it will be reduced when the
+            quantity monitored has stopped increasing. Defaults to 'less'.
+            ``rule`` corresponds to ``mode`` in PyTorch's ReduceLROnPlateau.
+        factor (float): Factor by which the momentum will be
+            reduced. new_param = param * factor. Defaults to 0.1.
+        patience (int): Number of epochs with no improvement after
+            which momentum will be reduced. For example, if
+            ``patience = 2``, then we will ignore the first 2 epochs
+            with no improvement, and will only decrease the momentum after
+            the 3rd epoch if the monitor value still hasn't improved then.
+            Defaults to 10.
+        threshold (float): Threshold for measuring the new optimum,
+            to only focus on significant changes. Defaults to 1e-4.
+        threshold_rule (str): One of `rel`, `abs`. In `rel` rule,
+            dynamic_threshold = best * (1 + threshold) in `greater`
+            rule or best * (1 - threshold) in `less` rule.
+            In `abs` rule, dynamic_threshold = best + threshold in
+            `greater` rule or best - threshold in `less` rule.
+            Defaults to 'rel'.
+        cooldown (int): Number of epochs to wait before resuming
+            normal operation after momentum has been reduced. Defaults to 0.
+        min_value (float or list[float]): A scalar or a sequence of scalars
+            giving a lower bound on the momentum of each parameter group
+            respectively. Defaults to 0.
+        eps (float): Minimal decay applied to momentum. If the difference
+            between new and old momentum is smaller than eps, the update is
+            ignored. Defaults to 1e-8.
+        begin (int): Epoch at which the scheduler starts to monitor the
+            validation metrics. The interval is counted in training
+            epochs. Defaults to 0.
+        end (int): Epoch at which the scheduler stops monitoring the
+            validation metrics. The interval is counted in training
+            epochs. Defaults to INF.
+        last_step (int): The index of last step. Used for resume without
+            state dict. Defaults to -1.
+        by_epoch (bool): Whether the scheduled parameters are updated by
+            epochs. Defaults to True.
+        verbose (bool): Whether to print the value for each update.
+            Defaults to False.
+    """
+
+    def step(self, metrics=None):
+        """Adjusts the momentum of each parameter group based on the specified
+        schedule.
+
+        Args:
+            metrics (Dict[str, float], optional): Evaluation results of all
+                metrics on validation dataset. The keys are the names of the
+                metrics, and the values are corresponding results.
+                Defaults to None.
+        """
+        super(MomentumSchedulerMixin, self).step(metrics)
+        if self.use_betas:
+            for group in self.optimizer.param_groups:
+                _, beta_1 = group['betas']
+                # update the betas with the calculated value
+                group['betas'] = (group['momentum'], beta_1)
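
A small sketch of the betas bookkeeping above, assuming an Adam optimizer: the mixin exposes `betas[0]` as `momentum`, and the overridden `step` writes the reduced value back, so both stay in sync (`patience=0` here just forces a quick reduction for illustration):

```python
import torch.nn as nn
import torch.optim as optim

from mmengine.optim.scheduler import ReduceOnPlateauMomentum

model = nn.Linear(2, 2)
optimizer = optim.Adam(model.parameters(), lr=0.01, betas=(0.9, 0.999))
scheduler = ReduceOnPlateauMomentum(
    optimizer, monitor='loss', factor=0.1, patience=0)

for _ in range(2):  # the second stagnant epoch triggers a reduction
    scheduler.step(dict(loss=1.0))

group = optimizer.param_groups[0]
assert group['betas'][0] == group['momentum']  # kept in sync
```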
diff --git a/mmengine/optim/scheduler/param_scheduler.py b/mmengine/optim/scheduler/param_scheduler.py
index 7fdf5c3d8874ad107703fe2840e09f0fee08fbe7..1bd15e9c4fe61385048da74b7493c6c28a1362f4 100644
--- a/mmengine/optim/scheduler/param_scheduler.py
+++ b/mmengine/optim/scheduler/param_scheduler.py
@@ -951,7 +951,7 @@ class OneCycleParamScheduler(_ParamScheduler):
 
     .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
         https://arxiv.org/abs/1708.07120
-    """# noqa E501
+    """  # noqa E501
 
     def __init__(self,
                  optimizer: Union[Optimizer, OptimWrapper],
@@ -1058,7 +1058,7 @@ class OneCycleParamScheduler(_ParamScheduler):
             if len(param) != len(optimizer.param_groups):
                 raise ValueError(
                     f'expected {len(optimizer.param_groups)} values '
-                    f'for {name}, got { len(param)}')
+                    f'for {name}, got {len(param)}')
             return param
         else:
             return [param] * len(optimizer.param_groups)
@@ -1283,3 +1283,278 @@ class CosineRestartParamScheduler(_ParamScheduler):
             if iteration < period:
                 return i
         return None
+
+
+@PARAM_SCHEDULERS.register_module()
+class ReduceOnPlateauParamScheduler(_ParamScheduler):
+    """Reduce the parameters of each parameter group when a metric has stopped
+    improving. Models often benefit from reducing the parameters by a factor of
+    2-10 once learning stagnates. This scheduler reads a metrics quantity and
+    if no improvement is seen for a ``patience`` number of epochs, the
+    parameters are reduced.
+
+    The implementation is motivated by
+    `PyTorch ReduceLROnPlateau
+    <https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py>`_.
+
+    Args:
+        optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
+            optimizer.
+        param_name (str): Name of the parameter to be adjusted, such as
+            ``lr``, ``momentum``.
+        monitor (str): The name of the metric used to measure whether
+            the model performance has improved.
+        rule (str): One of `less`, `greater`. In `less` rule, parameters will
+            be reduced when the quantity monitored has stopped
+            decreasing; in `greater` rule it will be reduced when the
+            quantity monitored has stopped increasing. Defaults to 'less'.
+            ``rule`` corresponds to ``mode`` in PyTorch's ReduceLROnPlateau.
+        factor (float): Factor by which the parameters will be
+            reduced. new_param = param * factor. Defaults to 0.1.
+        patience (int): Number of epochs with no improvement after
+            which parameters will be reduced. For example, if
+            ``patience = 2``, then we will ignore the first 2 epochs
+            with no improvement, and will only decrease the parameters after
+            the 3rd epoch if the monitor value still hasn't improved then.
+            Defaults to 10.
+        threshold (float): Threshold for measuring the new optimum,
+            to only focus on significant changes. Defaults to 1e-4.
+        threshold_rule (str): One of `rel`, `abs`. In `rel` rule,
+            dynamic_threshold = best * (1 + threshold) in `greater`
+            rule or best * (1 - threshold) in `less` rule.
+            In `abs` rule, dynamic_threshold = best + threshold in
+            `greater` rule or best - threshold in `less` rule.
+            Defaults to 'rel'.
+        cooldown (int): Number of epochs to wait before resuming
+            normal operation after parameters have been reduced. Defaults to 0.
+        min_value (float or list[float]): A scalar or a sequence of scalars
+            giving a lower bound on the parameters of each parameter group
+            respectively. Defaults to 0.
+        eps (float): Minimal decay applied to parameters. If the difference
+            between new and old parameters are smaller than eps, the update is
+            ignored. Defaults to 1e-8.
+        begin (int): Epoch at which the scheduler starts to monitor the
+            validation metrics. The interval is counted in training
+            epochs. Defaults to 0.
+        end (int): Epoch at which the scheduler stops monitoring the
+            validation metrics. The interval is counted in training
+            epochs. Defaults to INF.
+        last_step (int): The index of last step. Used for resume without
+            state dict. Defaults to -1.
+        by_epoch (bool): Whether the scheduled parameters are updated by
+            epochs. Defaults to True.
+        verbose (bool): Whether to print the value for each update.
+            Defaults to False.
+    """
+
+    need_val_args = True
+
+    def __init__(self,
+                 optimizer: OptimizerType,
+                 param_name: str,
+                 monitor: str = 'loss',
+                 rule: str = 'less',
+                 factor: float = 0.1,
+                 patience: int = 10,
+                 threshold: float = 1e-4,
+                 threshold_rule: str = 'rel',
+                 cooldown: int = 0,
+                 min_value: Union[float, Sequence[float]] = 0.,
+                 eps: float = 1e-8,
+                 begin: int = 0,
+                 end: int = INF,
+                 last_step: int = -1,
+                 by_epoch: bool = True,
+                 verbose: bool = False):
+
+        # Attach optimizer
+        if not isinstance(optimizer, (Optimizer, OptimWrapper)):
+            raise TypeError('``optimizer`` should be an Optimizer, '
+                            'but got {}'.format(type(optimizer).__name__))
+        self.optimizer = optimizer
+        self.param_name = param_name
+
+        if end <= begin:
+            raise ValueError('end should be larger than begin, but got'
+                             ' begin={}, end={}'.format(begin, end))
+        self.begin = begin
+        self.end = end
+
+        assert by_epoch, \
+            f'Now {type(self).__name__} only support by_epoch=True'
+        self.by_epoch = by_epoch
+
+        assert isinstance(last_step, int) and last_step >= -1
+        # Initialize valid step count and base values
+        if last_step == -1:
+            for group in optimizer.param_groups:
+                # If the param has never been scheduled, record the
+                # current value as the initial value.
+                group.setdefault(f'initial_{param_name}', group[param_name])
+        else:
+            for i, group in enumerate(optimizer.param_groups):
+                if f'initial_{param_name}' not in group:
+                    raise KeyError(
+                        f"param 'initial_{param_name}' is not specified "
+                        'in param_groups[{}] when resuming an optimizer'.
+                        format(i))
+
+        self.last_step = last_step
+
+        self._global_step = 0
+        self.verbose = verbose
+
+        if factor >= 1.0:
+            raise ValueError('Factor should be < 1.0.')
+        self.factor = factor
+
+        if isinstance(min_value, (list, tuple)):
+            if len(min_value) != len(optimizer.param_groups):
+                raise ValueError('expected {} min_values, got {}'.format(
+                    len(optimizer.param_groups), len(min_value)))
+            self.min_values = list(min_value)
+        else:
+            self.min_values = [min_value] * len(  # type: ignore
+                optimizer.param_groups)
+
+        self.patience = patience
+        self.cooldown = cooldown
+        self.cooldown_counter = 0
+        self.rule_worse = None  # the worst value for the chosen rule
+        self.best = None
+        self.num_bad_epochs = 0
+        self.eps = eps
+
+        self.monitor = monitor
+        self._init_is_better(
+            rule=rule, threshold=threshold, threshold_rule=threshold_rule)
+        self._reset()
+
+        # unlike _ParamScheduler, self.step() is not called in __init__
+        self._last_value = [
+            group[self.param_name] for group in self.optimizer.param_groups
+        ]
+
+    def step(self, metrics=None):
+        """Adjusts the parameter value of each parameter group based on the
+        specified schedule.
+
+        Args:
+            metrics (Dict[str, float], optional): Evaluation results of all
+                metrics on validation dataset. The keys are the names of the
+                metrics, and the values are corresponding results.
+                Defaults to None.
+        """
+        if metrics is None:
+            # without metrics, only count self._global_step
+            self._global_step += 1
+            return
+
+        if not isinstance(metrics, dict):
+            raise TypeError('metrics type should be dict,'
+                            f' but got type {type(metrics)}')
+
+        # Compute parameter value per param group in the effective range
+        if self.begin <= self._global_step < self.end:
+            self.last_step += 1
+
+            # fetch the monitored value from the metrics dict
+            metric = metrics.get(self.monitor, None)
+            if metric is not None:
+                if self._is_better(metric, self.best):
+                    self.best = metric
+                    self.num_bad_epochs = 0
+                else:
+                    self.num_bad_epochs += 1
+
+                if self._in_cooldown():
+                    self.cooldown_counter -= 1
+                    self.num_bad_epochs = 0  # ignore bad epochs in cooldown
+
+                if self.num_bad_epochs > self.patience:
+                    values = self._get_value()
+
+                    for i, data in enumerate(
+                            zip(self.optimizer.param_groups, values)):
+                        param_group, value = data
+                        if param_group[self.param_name] - value > self.eps:
+                            param_group[self.param_name] = value
+                            self.print_value(self.verbose, i, value)
+                    self.cooldown_counter = self.cooldown
+                    self.num_bad_epochs = 0
+
+            else:
+                raise KeyError(f'Expected the key {self.monitor} in metrics, '
+                               f'but got keys {list(metrics.keys())}')
+
+        self._last_value = [
+            group[self.param_name] for group in self.optimizer.param_groups
+        ]
+
+    def print_value(self, is_verbose: bool, group: int, value: float) -> None:
+        """Display the current parameter value.
+
+        Args:
+            is_verbose (bool): Whether to print the value.
+            group (int): The index of the current ``param_group``.
+            value (float): The parameter value.
+        """
+        if is_verbose:
+            step_name = 'epoch' if self.by_epoch else 'iter'
+            print_log(
+                f'Adjusting parameter value of group {group} to {value:.4e} '
+                f'in {step_name} {self.last_step}.',
+                logger='current')
+
+    def _get_value(self):
+        """Compute value using chainable form of the scheduler."""
+        values = [
+            float(group[self.param_name]) * self.factor
+            for group in self.optimizer.param_groups
+        ]
+        return [max(v, min_v) for v, min_v in zip(values, self.min_values)]
+
+    def _in_cooldown(self):
+        """Judge whether it is in cooldown."""
+        return self.cooldown_counter > 0
+
+    def _is_better(self, a, best):
+        """Judge whether the monitor value is better."""
+        if self.rule == 'less' and self.threshold_rule == 'rel':
+            rel_epsilon = 1. - self.threshold
+            return a < best * rel_epsilon
+
+        elif self.rule == 'less' and self.threshold_rule == 'abs':
+            return a < best - self.threshold
+
+        elif self.rule == 'greater' and self.threshold_rule == 'rel':
+            rel_epsilon = self.threshold + 1.
+            return a > best * rel_epsilon
+
+        else:  # rule == 'greater' and threshold_rule == 'abs'
+            return a > best + self.threshold
+
+    def _init_is_better(self, rule, threshold, threshold_rule):
+        """Initialize rule and its associated values."""
+        if threshold < 0:
+            raise ValueError(f'threshold {threshold} should be >= 0.')
+        if rule not in {'less', 'greater'}:
+            raise ValueError(f'rule {rule} is unknown!')
+        if threshold_rule not in {'rel', 'abs'}:
+            raise ValueError(f'threshold rule {threshold_rule}'
+                             ' is unknown!')
+
+        if rule == 'less':
+            self.rule_worse = INF
+        else:  # rule == 'greater':
+            self.rule_worse = -INF
+
+        self.rule = rule
+        self.threshold = threshold
+        self.threshold_rule = threshold_rule
+
+    def _reset(self):
+        """Resets num_bad_epochs counter and cooldown counter."""
+        self.best = self.rule_worse
+        self.cooldown_counter = 0
+        self.num_bad_epochs = 0
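
To make the patience/cooldown bookkeeping in `step()` above concrete, here is a standalone sketch of just the counter logic, traced on the metric sequence used in the tests below (`rule='less'`, `threshold_rule='rel'`); it mirrors the branches of `step()` but is not the class itself:

```python
INF = float('inf')
factor, patience, cooldown, threshold = 0.1, 2, 1, 0.01

value, best = 0.05, INF
num_bad_epochs, cooldown_counter = 0, 0
for metric in [10., 9., 8., 7., 6., 6., 6., 6., 6., 6.]:
    if metric < best * (1. - threshold):  # significant improvement
        best, num_bad_epochs = metric, 0
    else:
        num_bad_epochs += 1
    if cooldown_counter > 0:  # bad epochs are ignored during cooldown
        cooldown_counter -= 1
        num_bad_epochs = 0
    if num_bad_epochs > patience:  # reduce on the 3rd stagnant epoch
        value *= factor
        cooldown_counter, num_bad_epochs = cooldown, 0

assert value == 0.05 * 0.1  # reduced exactly once
```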
diff --git a/tests/test_hooks/test_param_scheduler_hook.py b/tests/test_hooks/test_param_scheduler_hook.py
index 85c7fb2297c75c5ddb6f6ca14b14f37275cb7230..d4f6b18ee29177d56e66697e7679f5b50303e6d7 100644
--- a/tests/test_hooks/test_param_scheduler_hook.py
+++ b/tests/test_hooks/test_param_scheduler_hook.py
@@ -4,13 +4,14 @@ from unittest.mock import Mock
 import pytest
 
 from mmengine.hooks import ParamSchedulerHook
+from mmengine.optim import _ParamScheduler
 
 
 class TestParamSchedulerHook:
     error_msg = ('runner.param_schedulers should be list of ParamScheduler or '
                  'a dict containing list of ParamScheduler')
 
-    def test_after_iter(self):
+    def test_after_train_iter(self):
         # runner.param_schedulers should be a list or dict
         with pytest.raises(TypeError, match=self.error_msg):
             hook = ParamSchedulerHook()
@@ -42,9 +43,10 @@ class TestParamSchedulerHook:
         runner.param_schedulers = dict(key1=[scheduler1], key2=[scheduler2])
         hook.after_train_epoch(runner)
         hook.after_train_iter(runner, 0)
+        scheduler1.step.assert_called()
         scheduler2.step.assert_called()
 
-    def test_after_epoch(self):
+    def test_after_train_epoch(self):
         # runner.param_schedulers should be a list or dict
         with pytest.raises(TypeError, match=self.error_msg):
             hook = ParamSchedulerHook()
@@ -53,7 +55,7 @@ class TestParamSchedulerHook:
             scheduler.step = Mock()
             scheduler.by_epoch = True
             runner.param_schedulers = scheduler
-            hook.after_train_iter(runner, 0)
+            hook.after_train_epoch(runner)
             scheduler.step.assert_called()
 
         # runner.param_schedulers is a list of schedulers
@@ -77,3 +79,51 @@ class TestParamSchedulerHook:
         hook.after_train_epoch(runner)
         scheduler1.step.assert_called()
         scheduler2.step.assert_called()
+
+    def test_after_val_epoch(self):
+        metrics = dict(loss=1.0)
+
+        # mock super _ParamScheduler class
+        class MockParamScheduler(_ParamScheduler):
+
+            def __init__(self):
+                pass
+
+            def _get_value(self):
+                pass
+
+        # runner.param_schedulers should be a list or dict
+        with pytest.raises(TypeError, match=self.error_msg):
+            hook = ParamSchedulerHook()
+            runner = Mock()
+            scheduler = Mock()
+            scheduler.step = Mock()
+            scheduler.by_epoch = True
+            scheduler.need_val_args = True
+            runner.param_schedulers = scheduler
+            hook.after_val_epoch(runner, metrics)
+
+        # runner.param_schedulers is a list of schedulers
+        hook = ParamSchedulerHook()
+        runner = Mock()
+        scheduler = MockParamScheduler()
+        scheduler.step = Mock()
+        scheduler.by_epoch = True
+        scheduler.need_val_args = True
+        runner.param_schedulers = [scheduler]
+        hook.after_val_epoch(runner, metrics)
+        scheduler.step.assert_called_with(metrics)
+
+        # runner.param_schedulers is a dict containing list of schedulers
+        scheduler1 = MockParamScheduler()
+        scheduler1.step = Mock()
+        scheduler1.by_epoch = True
+        scheduler1.need_val_args = True
+        scheduler2 = MockParamScheduler()
+        scheduler2.step = Mock()
+        scheduler2.by_epoch = True
+        scheduler2.need_val_args = True
+        runner.param_schedulers = dict(key1=[scheduler1], key2=[scheduler2])
+        hook.after_val_epoch(runner, metrics)
+        scheduler1.step.assert_called_with(metrics)
+        scheduler2.step.assert_called_with(metrics)
diff --git a/tests/test_optim/test_scheduler/test_lr_scheduler.py b/tests/test_optim/test_scheduler/test_lr_scheduler.py
index bd537b8680e2d0e4a8f72ddfc43f680cd5dbc0e1..c9cd6e1fe6f6209ccc95374e0d8fc049ee93e738 100644
--- a/tests/test_optim/test_scheduler/test_lr_scheduler.py
+++ b/tests/test_optim/test_scheduler/test_lr_scheduler.py
@@ -8,7 +8,8 @@ import torch.optim as optim
 
 from mmengine.optim.scheduler import (ConstantLR, CosineAnnealingLR,
                                       CosineRestartLR, ExponentialLR, LinearLR,
-                                      MultiStepLR, OneCycleLR, PolyLR, StepLR,
+                                      MultiStepLR, OneCycleLR, PolyLR,
+                                      ReduceOnPlateauLR, StepLR,
                                       _ParamScheduler)
 from mmengine.testing import assert_allclose
 
@@ -195,9 +196,16 @@ class TestLRScheduler(TestCase):
                               schedulers,
                               targets,
                               epochs=10,
-                              param_name='lr'):
+                              param_name='lr',
+                              step_kwargs=None):
         if isinstance(schedulers, _ParamScheduler):
             schedulers = [schedulers]
+        if step_kwargs is None:
+            step_kwarg = [{} for _ in range(len(schedulers))]
+            step_kwargs = [step_kwarg for _ in range(epochs)]
+        else:  # step_kwargs is not None
+            assert len(step_kwargs) == epochs
+            assert len(step_kwargs[0]) == len(schedulers)
         for epoch in range(epochs):
             for param_group, target in zip(self.optimizer.param_groups,
                                            targets):
@@ -209,7 +217,10 @@ class TestLRScheduler(TestCase):
                         param_group[param_name]),
                     atol=1e-5,
                     rtol=0)
-            [scheduler.step() for scheduler in schedulers]
+            [
+                scheduler.step(**step_kwargs[epoch][i])
+                for i, scheduler in enumerate(schedulers)
+            ]
 
     def test_step_scheduler(self):
         # lr = 0.05     if epoch < 3
@@ -361,11 +372,176 @@ class TestLRScheduler(TestCase):
             eta_min=0)
         self._test_scheduler_value(scheduler, targets, epochs=10)
 
-    def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
+    def test_reduce_on_plateau_scheduler(self):
+        # ReduceOnPlateauLR inherits _ParamScheduler without calling
+        # super().__init__(), so the shared checks are retested here
+
+        # Test error in __init__ method
+        with self.assertRaises(TypeError):
+            ReduceOnPlateauLR('invalid_optimizer')
+        with self.assertRaises(ValueError):
+            ReduceOnPlateauLR(self.optimizer, begin=10, end=5)
+        with self.assertRaises(AssertionError):
+            ReduceOnPlateauLR(self.optimizer, by_epoch=False)
+
+        for last_step in (1.5, -2):
+            with self.assertRaises(AssertionError):
+                ReduceOnPlateauLR(self.optimizer, last_step=last_step)
+
+        with self.assertRaises(ValueError):
+            ReduceOnPlateauLR(self.optimizer, factor=2.0)
+        ReduceOnPlateauLR(self.optimizer, min_value=[0.1, 0.1])
+        with self.assertRaises(ValueError):
+            ReduceOnPlateauLR(self.optimizer, min_value=[0.1, 0.1, 0.1, 0.1])
+        with self.assertRaises(ValueError):
+            ReduceOnPlateauLR(self.optimizer, threshold=-1.0)
+        with self.assertRaises(ValueError):
+            ReduceOnPlateauLR(self.optimizer, rule='foo')
+        with self.assertRaises(ValueError):
+            ReduceOnPlateauLR(self.optimizer, threshold_rule='foo')
+
+        # Test error in step method
+        scheduler = ReduceOnPlateauLR(self.optimizer, monitor='loss')
+        assert scheduler.step() is None
+
+        with self.assertRaises(TypeError):
+            scheduler.step(('foo', 1.0))
+
+        metrics = dict(loss_foo=1.0)
+        with self.assertRaises(KeyError):
+            scheduler.step(metrics)
+
+        # Test scheduler value
+        def _test_value(epochs, targets, metrics_list, monitor, rule, factor,
+                        patience, threshold, threshold_rule, cooldown,
+                        min_value):
+            lr = 0.05
+            momentum = 0.01
+            weight_decay = 5e-4
+            scheduler = ReduceOnPlateauLR(
+                self.optimizer,
+                monitor=monitor,
+                rule=rule,
+                factor=factor,
+                patience=patience,
+                threshold=threshold,
+                threshold_rule=threshold_rule,
+                cooldown=cooldown,
+                min_value=min_value,
+            )
+            self._test_scheduler_value(
+                scheduler, targets, epochs=epochs, step_kwargs=metrics_list)
+
+            # reset the state of optimizers
+            self.optimizer = optim.SGD([{
+                'params': self.model.conv1.parameters()
+            }, {
+                'params': self.model.conv2.parameters(),
+                'lr': lr * self.layer2_mult,
+            }],
+                                       lr=lr,
+                                       momentum=momentum,
+                                       weight_decay=weight_decay)
+
+        epochs = 10
+        factor = 0.1
+        cooldown = 1
+        patience = 2
+
+        # rule(less) and threshold_rule(rel)
+        rule, threshold_rule = 'less', 'rel'
+        threshold = 0.01
+        monitor = 'loss'
+        metric_values = [10., 9., 8., 7., 6., 6., 6., 6., 6., 6.]
+        metrics_list = [[dict(metrics={monitor: v})] for v in metric_values]
+        single_targets = [
+            0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005
+        ]
+        targets = [
+            single_targets, [t * self.layer2_mult for t in single_targets]
+        ]
+
+        _test_value(epochs, targets, metrics_list, monitor, rule, factor,
+                    patience, threshold, threshold_rule, cooldown, 0.0)
+
+        # rule(less) and threshold_rule(abs)
+        rule, threshold_rule = 'less', 'abs'
+        threshold = 0.9
+        monitor = 'loss'
+        metric_values = [10., 9., 8., 7., 6., 6., 6., 6., 6., 6.]
+        metrics_list = [[dict(metrics={monitor: v})] for v in metric_values]
+        single_targets = [
+            0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005
+        ]
+        targets = [
+            single_targets, [t * self.layer2_mult for t in single_targets]
+        ]
+
+        _test_value(epochs, targets, metrics_list, monitor, rule, factor,
+                    patience, threshold, threshold_rule, cooldown, 0.0)
+
+        # rule(greater) and threshold_rule(rel)
+        rule, threshold_rule = 'greater', 'rel'
+        threshold = 0.01
+        monitor = 'bbox_mAP'
+        metric_values = [1., 2., 3., 4., 5., 5., 5., 5., 5., 5.]
+        metrics_list = [[dict(metrics={monitor: v})] for v in metric_values]
+        single_targets = [
+            0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005
+        ]
+        targets = [
+            single_targets, [t * self.layer2_mult for t in single_targets]
+        ]
+
+        _test_value(epochs, targets, metrics_list, monitor, rule, factor,
+                    patience, threshold, threshold_rule, cooldown, 0.0)
+
+        # rule(greater) and threshold_rule(abs)
+        rule, threshold_rule = 'greater', 'abs'
+        threshold = 0.9
+        monitor = 'bbox_mAP'
+        metric_values = [1., 2., 3., 4., 5., 5., 5., 5., 5., 5.]
+        metrics_list = [[dict(metrics={monitor: v})] for v in metric_values]
+        single_targets = [
+            0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005
+        ]
+        targets = [
+            single_targets, [t * self.layer2_mult for t in single_targets]
+        ]
+
+        _test_value(epochs, targets, metrics_list, monitor, rule, factor,
+                    patience, threshold, threshold_rule, cooldown, 0.0)
+
+        # change min_value
+        min_value = 0.01
+        rule, threshold_rule = 'less', 'rel'
+        threshold = 0.01
+        monitor = 'loss'
+        metric_values = [10., 9., 8., 7., 6., 6., 6., 6., 6., 6.]
+        metrics_list = [[dict(metrics={monitor: v})] for v in metric_values]
+        single_targets_1 = [
+            0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, min_value,
+            min_value
+        ]
+        single_targets_2 = [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.05, 0.05]
+        targets = [single_targets_1, single_targets_2]
+
+        _test_value(epochs, targets, metrics_list, monitor, rule, factor,
+                    patience, threshold, threshold_rule, cooldown, min_value)
+
+    def _check_scheduler_state_dict(self,
+                                    construct,
+                                    construct2,
+                                    epochs=10,
+                                    step_kwargs=None):
+        if step_kwargs is None:
+            step_kwargs = [{} for _ in range(epochs)]
+        else:  # step_kwargs is not None
+            assert len(step_kwargs) == epochs
         scheduler = construct()
-        for _ in range(epochs):
+        for epoch in range(epochs):
             scheduler.optimizer.step()
-            scheduler.step()
+            scheduler.step(**step_kwargs[epoch])
         scheduler_copy = construct2()
         scheduler_copy.load_state_dict(scheduler.state_dict())
         for key in scheduler.__dict__.keys():
@@ -429,6 +605,35 @@ class TestLRScheduler(TestCase):
                 eta_min=0),
             epochs=10)
 
+    def test_reduce_on_plateau_scheduler_state_dict(self):
+        epochs = 10
+        metrics_list = [dict(metrics=dict(loss=1.0)) for _ in range(epochs)]
+        self._check_scheduler_state_dict(
+            lambda: ReduceOnPlateauLR(
+                self.optimizer,
+                monitor='loss',
+                rule='less',
+                factor=0.01,
+                patience=5,
+                threshold=1e-4,
+                threshold_rule='rel',
+                cooldown=0,
+                min_value=0.0,
+                eps=1e-8),
+            lambda: ReduceOnPlateauLR(
+                self.optimizer,
+                monitor='loss_foo',
+                rule='greater',
+                factor=0.05,
+                patience=10,
+                threshold=1e-5,
+                threshold_rule='abs',
+                cooldown=5,
+                min_value=0.1,
+                eps=1e-9),
+            epochs=epochs,
+            step_kwargs=metrics_list)
+
     def test_step_scheduler_convert_iterbased(self):
         # invalid epoch_length
         with self.assertRaises(AssertionError):
diff --git a/tests/test_optim/test_scheduler/test_momentum_scheduler.py b/tests/test_optim/test_scheduler/test_momentum_scheduler.py
index d98f2fe424a2f915eb9c200b6b0d423e67bab15c..942259d7dad4d7a1ef56ba90f7bd9fdeb6e85c13 100644
--- a/tests/test_optim/test_scheduler/test_momentum_scheduler.py
+++ b/tests/test_optim/test_scheduler/test_momentum_scheduler.py
@@ -6,12 +6,15 @@ import torch
 import torch.nn.functional as F
 import torch.optim as optim
 
+# yapf: disable
 from mmengine.optim.scheduler import (ConstantMomentum,
                                       CosineAnnealingMomentum,
                                       CosineRestartMomentum,
                                       ExponentialMomentum, LinearMomentum,
                                       MultiStepMomentum, PolyMomentum,
-                                      StepMomentum, _ParamScheduler)
+                                      ReduceOnPlateauMomentum, StepMomentum,
+                                      _ParamScheduler)
+# yapf: enable
 from mmengine.testing import assert_allclose
 
 
@@ -213,9 +216,16 @@ class TestMomentumScheduler(TestCase):
                               schedulers,
                               targets,
                               epochs=10,
-                              param_name='momentum'):
+                              param_name='momentum',
+                              step_kwargs=None):
         if isinstance(schedulers, _ParamScheduler):
             schedulers = [schedulers]
+        if step_kwargs is None:
+            step_kwarg = [{} for _ in range(len(schedulers))]
+            step_kwargs = [step_kwarg for _ in range(epochs)]
+        else:  # step_kwargs is not None
+            assert len(step_kwargs) == epochs
+            assert len(step_kwargs[0]) == len(schedulers)
         for epoch in range(epochs):
             for param_group, target in zip(optimizer.param_groups, targets):
                 assert_allclose(
@@ -235,7 +245,10 @@ class TestMomentumScheduler(TestCase):
                                param_group['betas'][0]),
                         atol=1e-5,
                         rtol=0)
-            [scheduler.step() for scheduler in schedulers]
+            [
+                scheduler.step(**step_kwargs[epoch][i])
+                for i, scheduler in enumerate(schedulers)
+            ]
 
     def test_step_scheduler(self):
         # momentum = 0.05     if epoch < 3
@@ -437,11 +450,204 @@ class TestMomentumScheduler(TestCase):
         self._test_scheduler_value(
             self.optimizer_with_betas, scheduler, targets, epochs=10)
 
-    def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
+    def test_reduce_on_plateau_scheduler(self):
+        # ReduceOnPlateauMomentum inherits _ParamScheduler without calling
+        # super().__init__(), so the shared checks are retested here
+
+        # Test error in __init__ method
+        with self.assertRaises(ValueError):
+            optimizer = optim.ASGD(
+                self.model.parameters(),
+                lr=0.01,
+            )
+            ReduceOnPlateauMomentum(optimizer)
+        with self.assertRaises(ValueError):
+            ReduceOnPlateauMomentum(self.optimizer, begin=10, end=5)
+        with self.assertRaises(AssertionError):
+            ReduceOnPlateauMomentum(self.optimizer, by_epoch=False)
+
+        for last_step in (1.5, -2):
+            with self.assertRaises(AssertionError):
+                ReduceOnPlateauMomentum(self.optimizer, last_step=last_step)
+
+        with self.assertRaises(ValueError):
+            ReduceOnPlateauMomentum(self.optimizer, factor=2.0)
+        ReduceOnPlateauMomentum(self.optimizer, min_value=[0.1, 0.1])
+        with self.assertRaises(ValueError):
+            ReduceOnPlateauMomentum(
+                self.optimizer, min_value=[0.1, 0.1, 0.1, 0.1])
+        with self.assertRaises(ValueError):
+            ReduceOnPlateauMomentum(self.optimizer, threshold=-1.0)
+        with self.assertRaises(ValueError):
+            ReduceOnPlateauMomentum(self.optimizer, rule='foo')
+        with self.assertRaises(ValueError):
+            ReduceOnPlateauMomentum(self.optimizer, threshold_rule='foo')
+
+        # Test error in step method
+        scheduler = ReduceOnPlateauMomentum(self.optimizer, monitor='loss')
+        assert scheduler.step() is None
+
+        with self.assertRaises(TypeError):
+            scheduler.step(('foo', 1.0))
+
+        metrics = dict(loss_foo=1.0)
+        with self.assertRaises(KeyError):
+            scheduler.step(metrics)
+
+        # Test scheduler value
+        def _test_value(epochs, targets, metrics_list, optimizer, monitor,
+                        rule, factor, patience, threshold, threshold_rule,
+                        cooldown, min_value):
+            lr = 0.01
+            momentum = 0.05
+            weight_decay = 5e-4
+            scheduler = ReduceOnPlateauMomentum(
+                optimizer,
+                monitor=monitor,
+                rule=rule,
+                factor=factor,
+                patience=patience,
+                threshold=threshold,
+                threshold_rule=threshold_rule,
+                cooldown=cooldown,
+                min_value=min_value,
+            )
+            self._test_scheduler_value(
+                optimizer,
+                scheduler,
+                targets,
+                epochs=epochs,
+                step_kwargs=metrics_list)
+
+            # reset the state of optimizers
+            self.optimizer = optim.SGD([{
+                'params': self.model.conv1.parameters()
+            }, {
+                'params': self.model.conv2.parameters(),
+                'momentum': momentum * self.layer2_mult
+            }],
+                                       lr=lr,
+                                       momentum=momentum,
+                                       weight_decay=weight_decay)
+            self.optimizer_with_betas = optim.Adam(
+                [{
+                    'params': self.model.conv1.parameters()
+                }, {
+                    'params': self.model.conv2.parameters(),
+                    'betas': (momentum * self.layer2_mult, 0.999)
+                }],
+                lr=lr,
+                betas=(momentum, 0.999),
+                weight_decay=weight_decay)
+
+        epochs = 10
+        factor = 0.1
+        cooldown = 1
+        patience = 2
+
+        # rule(less) and threshold_rule(rel)
+        rule, threshold_rule = 'less', 'rel'
+        threshold = 0.01
+        monitor = 'loss'
+        metric_values = [10., 9., 8., 7., 6., 6., 6., 6., 6., 6.]
+        metrics_list = [[dict(metrics={monitor: v})] for v in metric_values]
+        single_targets = [
+            0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005
+        ]
+        targets = [
+            single_targets, [t * self.layer2_mult for t in single_targets]
+        ]
+
+        _test_value(epochs, targets, metrics_list, self.optimizer, monitor,
+                    rule, factor, patience, threshold, threshold_rule,
+                    cooldown, 0.0)
+
+        # rule(less) and threshold_rule(abs)
+        rule, threshold_rule = 'less', 'abs'
+        threshold = 0.9
+        monitor = 'loss'
+        metric_values = [10., 9., 8., 7., 6., 6., 6., 6., 6., 6.]
+        metrics_list = [[dict(metrics={monitor: v})] for v in metric_values]
+        single_targets = [
+            0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005
+        ]
+        targets = [
+            single_targets, [t * self.layer2_mult for t in single_targets]
+        ]
+
+        _test_value(epochs, targets, metrics_list, self.optimizer, monitor,
+                    rule, factor, patience, threshold, threshold_rule,
+                    cooldown, 0.0)
+
+        # rule(greater) and threshold_rule(rel)
+        rule, threshold_rule = 'greater', 'rel'
+        threshold = 0.01
+        monitor = 'bbox_mAP'
+        metric_values = [1., 2., 3., 4., 5., 5., 5., 5., 5., 5.]
+        metrics_list = [[dict(metrics={monitor: v})] for v in metric_values]
+        single_targets = [
+            0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005
+        ]
+        targets = [
+            single_targets, [t * self.layer2_mult for t in single_targets]
+        ]
+
+        _test_value(epochs, targets, metrics_list, self.optimizer, monitor,
+                    rule, factor, patience, threshold, threshold_rule,
+                    cooldown, 0.0)
+
+        # rule(greater) and threshold_rule(abs)
+        rule, threshold_rule = 'greater', 'abs'
+        threshold = 0.9
+        monitor = 'bbox_mAP'
+        metric_values = [1., 2., 3., 4., 5., 5., 5., 5., 5., 5.]
+        metrics_list = [[dict(metrics={monitor: v})] for v in metric_values]
+        single_targets = [
+            0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005
+        ]
+        targets = [
+            single_targets, [t * self.layer2_mult for t in single_targets]
+        ]
+
+        _test_value(epochs, targets, metrics_list, self.optimizer, monitor,
+                    rule, factor, patience, threshold, threshold_rule,
+                    cooldown, 0.0)
+
+        # change min_value
+        min_value = 0.01
+        rule, threshold_rule = 'less', 'rel'
+        threshold = 0.01
+        monitor = 'loss'
+        metric_values = [10., 9., 8., 7., 6., 6., 6., 6., 6., 6.]
+        metrics_list = [[dict(metrics={monitor: v})] for v in metric_values]
+        single_targets_1 = [
+            0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, min_value,
+            min_value
+        ]
+        single_targets_2 = [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.05, 0.05]
+        targets = [single_targets_1, single_targets_2]
+
+        _test_value(epochs, targets, metrics_list, self.optimizer, monitor,
+                    rule, factor, patience, threshold, threshold_rule,
+                    cooldown, min_value)
+
+        _test_value(epochs, targets, metrics_list, self.optimizer_with_betas,
+                    monitor, rule, factor, patience, threshold, threshold_rule,
+                    cooldown, min_value)
+
+    def _check_scheduler_state_dict(self,
+                                    construct,
+                                    construct2,
+                                    epochs=10,
+                                    step_kwargs=None):
+        if step_kwargs is None:
+            step_kwargs = [{} for _ in range(epochs)]
+        else:  # expect one step kwargs dict per epoch
+            assert len(step_kwargs) == epochs
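+        # ``step_kwargs`` lets metric-driven schedulers receive their
+        # metrics dict on every epoch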
         scheduler = construct()
-        for _ in range(epochs):
+        for epoch in range(epochs):
             scheduler.optimizer.step()
-            scheduler.step()
+            scheduler.step(**step_kwargs[epoch])
         scheduler_copy = construct2()
         scheduler_copy.load_state_dict(scheduler.state_dict())
         for key in scheduler.__dict__.keys():
@@ -506,6 +712,35 @@ class TestMomentumScheduler(TestCase):
                 eta_min=0),
             epochs=10)
 
+    def test_reduce_on_plateau_scheduler_state_dict(self):
+        epochs = 10
+        metrics_list = [dict(metrics=dict(loss=1.0)) for _ in range(epochs)]
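+        # the second scheduler is constructed with deliberately different
+        # settings; the state dict round trip below should restore them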
+        self._check_scheduler_state_dict(
+            lambda: ReduceOnPlateauMomentum(
+                self.optimizer,
+                monitor='loss',
+                rule='less',
+                factor=0.01,
+                patience=5,
+                threshold=1e-4,
+                threshold_rule='rel',
+                cooldown=0,
+                min_value=0.0,
+                eps=1e-8),
+            lambda: ReduceOnPlateauMomentum(
+                self.optimizer,
+                monitor='loss_foo',
+                rule='greater',
+                factor=0.05,
+                patience=10,
+                threshold=1e-5,
+                threshold_rule='abs',
+                cooldown=5,
+                min_value=0.1,
+                eps=1e-9),
+            epochs=epochs,
+            step_kwargs=metrics_list)
+
     def test_multi_scheduler_without_overlap_linear_multi_step(self):
         # use Linear in the first 5 epochs and then use MultiStep
         epochs = 12
diff --git a/tests/test_optim/test_scheduler/test_param_scheduler.py b/tests/test_optim/test_scheduler/test_param_scheduler.py
index ce86195b944fc9c696bdb30c596d718915ee5180..b210983143e5086f8074aa135f63d8c4b061b6ab 100644
--- a/tests/test_optim/test_scheduler/test_param_scheduler.py
+++ b/tests/test_optim/test_scheduler/test_param_scheduler.py
@@ -17,8 +17,9 @@ from mmengine.optim.scheduler import (ConstantParamScheduler,
                                       LinearParamScheduler,
                                       MultiStepParamScheduler,
                                       OneCycleParamScheduler,
-                                      PolyParamScheduler, StepParamScheduler,
-                                      _ParamScheduler)
+                                      PolyParamScheduler,
+                                      ReduceOnPlateauParamScheduler,
+                                      StepParamScheduler, _ParamScheduler)
 # yapf: enable
 from mmengine.testing import assert_allclose
 
@@ -239,9 +240,16 @@ class TestParameterScheduler(TestCase):
                               schedulers,
                               targets,
                               epochs=10,
-                              param_name='lr'):
+                              param_name='lr',
+                              step_kwargs=None):
         if isinstance(schedulers, _ParamScheduler):
             schedulers = [schedulers]
+        if step_kwargs is None:
+            step_kwarg = [{} for _ in range(len(schedulers))]
+            step_kwargs = [step_kwarg for _ in range(epochs)]
+        else:  # expect one entry per epoch, one kwargs dict per scheduler
+            assert len(step_kwargs) == epochs
+            assert len(step_kwargs[0]) == len(schedulers)
         for epoch in range(epochs):
             for param_group, target in zip(self.optimizer.param_groups,
                                            targets):
@@ -253,7 +261,10 @@ class TestParameterScheduler(TestCase):
                         param_group[param_name]),
                     atol=1e-5,
                     rtol=0)
-            [scheduler.step() for scheduler in schedulers]
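+            # step each scheduler with its own kwargs for this epoch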
+            [
+                scheduler.step(**step_kwargs[epoch][i])
+                for i, scheduler in enumerate(schedulers)
+            ]
 
     def test_step_scheduler(self):
         # lr = 0.05     if epoch < 3
@@ -488,11 +499,186 @@ class TestParameterScheduler(TestCase):
             eta_min=eta_min)
         self._test_scheduler_value(scheduler, targets, epochs=10)
 
-    def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
+    def test_reduce_on_plateau_scheduler(self):
+        # ReduceOnPlateauParamScheduler inherits from _ParamScheduler but
+        # does not call super().__init__(), so parts of the base-class
+        # behavior need to be retested here
+
+        # Test errors raised in the __init__ method
+        with self.assertRaises(TypeError):
+            ReduceOnPlateauParamScheduler('invalid_optimizer', param_name='lr')
+        with self.assertRaises(ValueError):
+            ReduceOnPlateauParamScheduler(
+                self.optimizer, 'lr', begin=10, end=5)
+        with self.assertRaises(AssertionError):
+            ReduceOnPlateauParamScheduler(self.optimizer, 'lr', by_epoch=False)
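+        # metric-driven plateau scheduling only supports by_epoch=True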
+
+        for last_step in (1.5, -2):
+            with self.assertRaises(AssertionError):
+                ReduceOnPlateauParamScheduler(
+                    self.optimizer, 'lr', last_step=last_step)
+
+        with self.assertRaises(ValueError):
+            ReduceOnPlateauParamScheduler(self.optimizer, 'lr', factor=2.0)
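+        # a list min_value is accepted when its length matches the number
+        # of param groups (two here); a mismatched length must fail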
+        ReduceOnPlateauParamScheduler(
+            self.optimizer, 'lr', min_value=[0.1, 0.1])
+        with self.assertRaises(ValueError):
+            ReduceOnPlateauParamScheduler(
+                self.optimizer, 'lr', min_value=[0.1, 0.1, 0.1, 0.1])
+        with self.assertRaises(ValueError):
+            ReduceOnPlateauParamScheduler(self.optimizer, 'lr', threshold=-1.0)
+        with self.assertRaises(ValueError):
+            ReduceOnPlateauParamScheduler(self.optimizer, 'lr', rule='foo')
+        with self.assertRaises(ValueError):
+            ReduceOnPlateauParamScheduler(
+                self.optimizer, 'lr', threshold_rule='foo')
+
+        # Test errors raised in the step method
+        scheduler = ReduceOnPlateauParamScheduler(
+            self.optimizer, param_name='lr', monitor='loss')
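+        # calling step() without metrics should be a no-op returning None;
+        # a non-dict argument raises TypeError and a metrics dict without
+        # the monitored key raises KeyError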
+        assert scheduler.step() is None
+
+        with self.assertRaises(TypeError):
+            scheduler.step(('foo', 1.0))
+
+        metrics = dict(loss_foo=1.0)
+        with self.assertRaises(KeyError):
+            scheduler.step(metrics)
+
+        # Test the scheduler values under each rule / threshold_rule
+        # combination
+        def _test_value(epochs, targets, metrics_list, monitor, rule, factor,
+                        patience, threshold, threshold_rule, cooldown,
+                        min_value):
+            lr = 0.05
+            momentum = 0.01
+            weight_decay = 5e-4
+            scheduler = ReduceOnPlateauParamScheduler(
+                self.optimizer,
+                param_name='lr',
+                monitor=monitor,
+                rule=rule,
+                factor=factor,
+                patience=patience,
+                threshold=threshold,
+                threshold_rule=threshold_rule,
+                cooldown=cooldown,
+                min_value=min_value,
+            )
+            self._test_scheduler_value(
+                scheduler, targets, epochs=epochs, step_kwargs=metrics_list)
+
+            # re-create the optimizer so the next case starts from the
+            # initial values
+            self.optimizer = optim.SGD(
+                [{
+                    'params': self.model.conv1.parameters()
+                }, {
+                    'params': self.model.conv2.parameters(),
+                    'lr': lr * self.layer2_mult,
+                    'momentum': momentum * self.layer2_mult,
+                    'weight_decay': weight_decay * self.layer2_mult
+                }],
+                lr=lr,
+                momentum=momentum,
+                weight_decay=weight_decay)
+
+        epochs = 10
+        factor = 0.1
+        cooldown = 1
+        patience = 2
+
+        # rule(less) and threshold_rule(rel)
+        rule, threshold_rule = 'less', 'rel'
+        threshold = 0.01
+        monitor = 'loss'
+        metric_values = [10., 9., 8., 7., 6., 6., 6., 6., 6., 6.]
+        metrics_list = [[dict(metrics={monitor: v})] for v in metric_values]
+        single_targets = [
+            0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005
+        ]
+        targets = [
+            single_targets, [t * self.layer2_mult for t in single_targets]
+        ]
+
+        _test_value(epochs, targets, metrics_list, monitor, rule, factor,
+                    patience, threshold, threshold_rule, cooldown, 0.0)
+
+        # rule(less) and threshold_rule(abs)
+        rule, threshold_rule = 'less', 'abs'
+        threshold = 0.9
+        monitor = 'loss'
+        metric_values = [10., 9., 8., 7., 6., 6., 6., 6., 6., 6.]
+        metrics_list = [[dict(metrics={monitor: v})] for v in metric_values]
+        single_targets = [
+            0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005
+        ]
+        targets = [
+            single_targets, [t * self.layer2_mult for t in single_targets]
+        ]
+
+        _test_value(epochs, targets, metrics_list, monitor, rule, factor,
+                    patience, threshold, threshold_rule, cooldown, 0.0)
+
+        # rule(greater) and threshold_rule(rel)
+        rule, threshold_rule = 'greater', 'rel'
+        threshold = 0.01
+        monitor = 'bbox_mAP'
+        metric_values = [1., 2., 3., 4., 5., 5., 5., 5., 5., 5.]
+        metrics_list = [[dict(metrics={monitor: v})] for v in metric_values]
+        single_targets = [
+            0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005
+        ]
+        targets = [
+            single_targets, [t * self.layer2_mult for t in single_targets]
+        ]
+
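+        # with rule 'greater' the monitored bbox_mAP has to keep increasing;
+        # it plateaus at 5.0 from epoch 4 on, so the reduction timing
+        # mirrors the 'less' cases above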
+        _test_value(epochs, targets, metrics_list, monitor, rule, factor,
+                    patience, threshold, threshold_rule, cooldown, 0.0)
+
+        # rule(greater) and threshold_rule(abs)
+        rule, threshold_rule = 'greater', 'abs'
+        threshold = 0.9
+        monitor = 'bbox_mAP'
+        metric_values = [1., 2., 3., 4., 5., 5., 5., 5., 5., 5.]
+        metrics_list = [[dict(metrics={monitor: v})] for v in metric_values]
+        single_targets = [
+            0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.005, 0.005
+        ]
+        targets = [
+            single_targets, [t * self.layer2_mult for t in single_targets]
+        ]
+
+        _test_value(epochs, targets, metrics_list, monitor, rule, factor,
+                    patience, threshold, threshold_rule, cooldown, 0.0)
+
+        # change min_value
+        min_value = 0.01
+        rule, threshold_rule = 'less', 'rel'
+        threshold = 0.01
+        monitor = 'loss'
+        metric_values = [10., 9., 8., 7., 6., 6., 6., 6., 6., 6.]
+        metrics_list = [[dict(metrics={monitor: v})] for v in metric_values]
+        single_targets_1 = [
+            0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, min_value,
+            min_value
+        ]
+        single_targets_2 = [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.05, 0.05]
+        targets = [single_targets_1, single_targets_2]
+
+        _test_value(epochs, targets, metrics_list, monitor, rule, factor,
+                    patience, threshold, threshold_rule, cooldown, min_value)
+
+    def _check_scheduler_state_dict(self,
+                                    construct,
+                                    construct2,
+                                    epochs=10,
+                                    step_kwargs=None):
+        if step_kwargs is None:
+            step_kwargs = [{} for _ in range(epochs)]
+        else:  # expect one step kwargs dict per epoch
+            assert len(step_kwargs) == epochs
         scheduler = construct()
-        for _ in range(epochs):
+        for epoch in range(epochs):
             scheduler.optimizer.step()
-            scheduler.step()
+            scheduler.step(**step_kwargs[epoch])
         scheduler_copy = construct2()
         torch.save(scheduler.state_dict(),
                    osp.join(self.temp_dir.name, 'tmp.pth'))
@@ -581,6 +767,37 @@ class TestParameterScheduler(TestCase):
                 eta_min=0),
             epochs=10)
 
+    def test_reduce_on_plateau_scheduler_state_dict(self):
+        epochs = 10
+        metrics_list = [dict(metrics=dict(loss=1.0)) for _ in range(epochs)]
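+        # a constant loss stops counting as an improvement after the first
+        # step, so the scheduler's bookkeeping (best value, bad-epoch and
+        # cooldown counters) keeps evolving and must survive the round trip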
+        self._check_scheduler_state_dict(
+            lambda: ReduceOnPlateauParamScheduler(
+                self.optimizer,
+                param_name='lr',
+                monitor='loss',
+                rule='less',
+                factor=0.01,
+                patience=5,
+                threshold=1e-4,
+                threshold_rule='rel',
+                cooldown=0,
+                min_value=0.0,
+                eps=1e-8),
+            lambda: ReduceOnPlateauParamScheduler(
+                self.optimizer,
+                param_name='lr',
+                monitor='loss_foo',
+                rule='greater',
+                factor=0.05,
+                patience=10,
+                threshold=1e-5,
+                threshold_rule='abs',
+                cooldown=5,
+                min_value=0.1,
+                eps=1e-9),
+            epochs=epochs,
+            step_kwargs=metrics_list)
+
     def test_step_scheduler_convert_iterbased(self):
         # invalid epoch_length
         with self.assertRaises(AssertionError):