diff --git a/mmengine/optim/scheduler/__init__.py b/mmengine/optim/scheduler/__init__.py
index df54f738ae3b5134b10d9809fc64dcfa4c9bb33e..c12d1c2970d39561e3ddec2e92e49f67bbfdc4ba 100644
--- a/mmengine/optim/scheduler/__init__.py
+++ b/mmengine/optim/scheduler/__init__.py
@@ -1,16 +1,22 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .lr_scheduler import (ConstantLR, CosineAnnealingLR, ExponentialLR,
-                           LinearLR, MultiStepLR, OneCycleLR, PolyLR, StepLR)
+# yapf: disable
+from .lr_scheduler import (ConstantLR, CosineAnnealingLR, CosineRestartLR,
+                           ExponentialLR, LinearLR, MultiStepLR, OneCycleLR,
+                           PolyLR, StepLR)
 from .momentum_scheduler import (ConstantMomentum, CosineAnnealingMomentum,
-                                 ExponentialMomentum, LinearMomentum,
-                                 MultiStepMomentum, PolyMomentum, StepMomentum)
+                                 CosineRestartMomentum, ExponentialMomentum,
+                                 LinearMomentum, MultiStepMomentum,
+                                 PolyMomentum, StepMomentum)
 from .param_scheduler import (ConstantParamScheduler,
                               CosineAnnealingParamScheduler,
+                              CosineRestartParamScheduler,
                               ExponentialParamScheduler, LinearParamScheduler,
                               MultiStepParamScheduler, OneCycleParamScheduler,
                               PolyParamScheduler, StepParamScheduler,
                               _ParamScheduler)
 
+# yapf: enable
+
 __all__ = [
     'ConstantLR', 'CosineAnnealingLR', 'ExponentialLR', 'LinearLR',
     'MultiStepLR', 'StepLR', 'ConstantMomentum', 'CosineAnnealingMomentum',
@@ -19,5 +25,6 @@ __all__ = [
     'ExponentialParamScheduler', 'LinearParamScheduler',
     'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler',
     'PolyParamScheduler', 'PolyLR', 'PolyMomentum', 'OneCycleParamScheduler',
-    'OneCycleLR'
+    'OneCycleLR', 'CosineRestartParamScheduler', 'CosineRestartLR',
+    'CosineRestartMomentum'
 ]
diff --git a/mmengine/optim/scheduler/lr_scheduler.py b/mmengine/optim/scheduler/lr_scheduler.py
index 21ae823e04a4a90eb5aa22c4d385b04b92ab81ed..cbb82dfc550ae933e5320a7db710ffca7fb72b3c 100644
--- a/mmengine/optim/scheduler/lr_scheduler.py
+++ b/mmengine/optim/scheduler/lr_scheduler.py
@@ -2,6 +2,7 @@
 from mmengine.registry import PARAM_SCHEDULERS
 from .param_scheduler import (ConstantParamScheduler,
                               CosineAnnealingParamScheduler,
+                              CosineRestartParamScheduler,
                               ExponentialParamScheduler, LinearParamScheduler,
                               MultiStepParamScheduler, OneCycleParamScheduler,
                               PolyParamScheduler, StepParamScheduler)
@@ -277,3 +278,35 @@ class OneCycleLR(LRSchedulerMixin, OneCycleParamScheduler):
     .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
         https://arxiv.org/abs/1708.07120
     """# noqa E501
+
+
+@PARAM_SCHEDULERS.register_module()
+class CosineRestartLR(LRSchedulerMixin, CosineRestartParamScheduler):
+    """Sets the learning rate of each parameter group according to the cosine
+    annealing with restarts scheme. The cosine restart policy anneals the
+    learning rate from the initial value to `eta_min` with a cosine annealing
+    schedule, then restarts the next period from the base learning rate
+    multiplied by the corresponding restart weight.
+
+    Args:
+        optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
+            optimizer.
+        periods (list[int]): Periods for each cosine annealing cycle.
+        restart_weights (list[float]): Restart weights at each
+            restart iteration. Defaults to [1].
+        eta_min (float, optional): Minimum learning rate at the end of
+            scheduling. Defaults to None.
+        eta_min_ratio (float, optional): The ratio of the minimum learning
+            rate to the base learning rate. Either `eta_min` or
+            `eta_min_ratio` should be specified. Defaults to None.
+        begin (int): Step at which to start updating the parameters.
+            Defaults to 0.
+        end (int): Step at which to stop updating the parameters.
+            Defaults to INF.
+        last_step (int): The index of last step. Used for resume without
+            state dict. Defaults to -1.
+        by_epoch (bool): Whether the scheduled parameters are updated by
+            epochs. Defaults to True.
+        verbose (bool): Whether to print the value for each update.
+            Defaults to False.
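+
+    Examples:
+        A minimal usage sketch (assumes a plain ``SGD`` optimizer; any
+        optimizer or wrapped optimizer works the same way):
+
+        >>> import torch
+        >>> from torch.optim import SGD
+        >>> model = torch.nn.Linear(2, 2)
+        >>> optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)
+        >>> # anneal for 4 epochs, then restart at half the base lr for 5
+        >>> scheduler = CosineRestartLR(
+        ...     optimizer,
+        ...     periods=[4, 5],
+        ...     restart_weights=[1, 0.5],
+        ...     eta_min=0)
+        >>> for _ in range(9):
+        ...     # training epoch runs here
+        ...     scheduler.step()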
+    """
diff --git a/mmengine/optim/scheduler/momentum_scheduler.py b/mmengine/optim/scheduler/momentum_scheduler.py
index 2cbe9d707914d7b53861c7bebbfdb2c1d62e74dc..9cb5f9e914cecd29ac595585c3672ddaedc25a6a 100644
--- a/mmengine/optim/scheduler/momentum_scheduler.py
+++ b/mmengine/optim/scheduler/momentum_scheduler.py
@@ -2,6 +2,7 @@
 from mmengine.registry import PARAM_SCHEDULERS
 from .param_scheduler import (ConstantParamScheduler,
                               CosineAnnealingParamScheduler,
+                              CosineRestartParamScheduler,
                               ExponentialParamScheduler, LinearParamScheduler,
                               MultiStepParamScheduler, PolyParamScheduler,
                               StepParamScheduler)
@@ -243,3 +244,36 @@ class PolyMomentum(MomentumSchedulerMixin, PolyParamScheduler):
         verbose (bool): Whether to print the value for each update.
             Defaults to False.
     """
+
+
+@PARAM_SCHEDULERS.register_module()
+class CosineRestartMomentum(MomentumSchedulerMixin,
+                            CosineRestartParamScheduler):
+    """Sets the momentum of each parameter group according to the cosine
+    annealing with restarts scheme. The cosine restart policy anneals the
+    momentum from the initial value to `eta_min` with a cosine annealing
+    schedule, then restarts the next period from the base momentum multiplied
+    by the corresponding restart weight.
+
+    Args:
+        optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
+            optimizer.
+        periods (list[int]): Periods for each cosine annealing cycle.
+        restart_weights (list[float]): Restart weights at each
+            restart iteration. Defaults to [1].
+        eta_min (float, optional): Minimum momentum value at the end of
+            scheduling. Defaults to None.
+        eta_min_ratio (float, optional): The ratio of the minimum momentum
+            value to the base momentum value. Either `eta_min` or
+            `eta_min_ratio` should be specified. Defaults to None.
+        begin (int): Step at which to start updating the parameters.
+            Defaults to 0.
+        end (int): Step at which to stop updating the parameters.
+            Defaults to INF.
+        last_step (int): The index of last step. Used for resume without
+            state dict. Defaults to -1.
+        by_epoch (bool): Whether the scheduled parameters are updated by
+            epochs. Defaults to True.
+        verbose (bool): Whether to print the value for each update.
+            Defaults to False.
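+
+    Examples:
+        A minimal usage sketch (assumes an ``SGD`` optimizer configured with
+        a ``momentum`` value; illustrative values only):
+
+        >>> import torch
+        >>> from torch.optim import SGD
+        >>> optimizer = SGD(
+        ...     torch.nn.Linear(2, 2).parameters(), lr=0.01, momentum=0.9)
+        >>> scheduler = CosineRestartMomentum(
+        ...     optimizer,
+        ...     periods=[4, 5],
+        ...     restart_weights=[1, 0.5],
+        ...     eta_min=0)
+        >>> for _ in range(9):
+        ...     scheduler.step()  # one call per epoch when by_epoch=True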
+    """
diff --git a/mmengine/optim/scheduler/param_scheduler.py b/mmengine/optim/scheduler/param_scheduler.py
index 50ea25a12e058c0ffa8051bee326d8b2280912a3..519e7b35b66a5cbc771482d70f0cd09252f7c223 100644
--- a/mmengine/optim/scheduler/param_scheduler.py
+++ b/mmengine/optim/scheduler/param_scheduler.py
@@ -9,7 +9,7 @@ import warnings
 import weakref
 from collections import Counter
 from functools import wraps
-from typing import Callable, List, Optional, Union
+from typing import Callable, List, Optional, Sequence, Union
 
 from torch.optim import Optimizer
 
@@ -227,6 +227,8 @@ class StepParamScheduler(_ParamScheduler):
 
     Args:
         optimizer (OptimWrapper or Optimizer): Wrapped optimizer.
+        param_name (str): Name of the parameter to be adjusted, such as
+            ``lr``, ``momentum``.
         step_size (int): Period of parameter value decay.
         gamma (float): Multiplicative factor of parameter value decay.
             Defaults to 0.1.
@@ -313,6 +315,8 @@ class MultiStepParamScheduler(_ParamScheduler):
 
     Args:
         optimizer (OptimWrapper or Optimizer): Wrapped optimizer.
+        param_name (str): Name of the parameter to be adjusted, such as
+            ``lr``, ``momentum``.
         milestones (list): List of epoch indices. Must be increasing.
         gamma (float): Multiplicative factor of parameter value decay.
             Defaults to 0.1.
@@ -401,6 +405,8 @@ class ConstantParamScheduler(_ParamScheduler):
     Args:
         optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
             optimizer.
+        param_name (str): Name of the parameter to be adjusted, such as
+            ``lr``, ``momentum``.
         factor (float): The number we multiply parameter value until the
             milestone. Defaults to 1./3.
         begin (int): Step at which to start updating the parameters.
@@ -488,6 +494,8 @@ class ExponentialParamScheduler(_ParamScheduler):
     Args:
         optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
             optimizer.
+        param_name (str): Name of the parameter to be adjusted, such as
+            ``lr``, ``momentum``.
         gamma (float): Multiplicative factor of parameter value decay.
         begin (int): Step at which to start updating the parameters.
             Defaults to 0.
@@ -585,6 +593,8 @@ class CosineAnnealingParamScheduler(_ParamScheduler):
     Args:
         optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
             optimizer.
+        param_name (str): Name of the parameter to be adjusted, such as
+            ``lr``, ``momentum``.
         T_max (int, optional): Maximum number of iterations. If not specified,
             use ``end - begin``. Defaults to None.
         eta_min (float): Minimum parameter value. Defaults to 0.
@@ -684,6 +694,8 @@ class LinearParamScheduler(_ParamScheduler):
     Args:
         optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
             optimizer.
+        param_name (str): Name of the parameter to be adjusted, such as
+            ``lr``, ``momentum``.
         start_factor (float): The number we multiply parameter value in the
             first epoch. The multiplication factor changes towards end_factor
             in the following epochs. Defaults to 1./3.
@@ -780,6 +792,8 @@ class PolyParamScheduler(_ParamScheduler):
     Args:
         optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
             optimizer.
+        param_name (str): Name of the parameter to be adjusted, such as
+            ``lr``, ``momentum``.
         eta_min (float): Minimum parameter value at the end of scheduling.
             Defaults to 0.
         power (float): The power of the polynomial. Defaults to 1.0.
@@ -882,6 +896,8 @@ class OneCycleParamScheduler(_ParamScheduler):
 
     Args:
         optimizer (Optimizer): Wrapped optimizer.
+        param_name (str): Name of the parameter to be adjusted, such as
+            ``lr``, ``momentum``.
         eta_max (float or list): Upper parameter value boundaries in the cycle
             for each parameter group.
         total_steps (int): The total number of steps in the cycle. Note that
@@ -1094,3 +1110,159 @@ class OneCycleParamScheduler(_ParamScheduler):
             params.append(computed_param)
 
         return params
+
+
+@PARAM_SCHEDULERS.register_module()
+class CosineRestartParamScheduler(_ParamScheduler):
+    """Sets the parameters of each parameter group according to the cosine
+    annealing with restarts scheme. The cosine restart policy anneals the
+    parameter from the initial value to `eta_min` with a cosine annealing
+    schedule, then restarts the next period from the base value multiplied by
+    the corresponding restart weight.
+
+    Args:
+        optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
+            optimizer.
+        param_name (str): Name of the parameter to be adjusted, such as
+            ``lr``, ``momentum``.
+        periods (list[int]): Periods for each cosine annealing cycle.
+        restart_weights (list[float]): Restart weights at each
+            restart iteration. Defaults to [1].
+        eta_min (float, optional): Minimum parameter value at the end of
+            scheduling. Defaults to None.
+        eta_min_ratio (float, optional): The ratio of the minimum parameter
+            value to the base parameter value. Either `eta_min` or
+            `eta_min_ratio` should be specified. Defaults to None.
+        begin (int): Step at which to start updating the parameters.
+            Defaults to 0.
+        end (int): Step at which to stop updating the parameters.
+            Defaults to INF.
+        last_step (int): The index of last step. Used for resume without
+            state dict. Defaults to -1.
+        by_epoch (bool): Whether the scheduled parameters are updated by
+            epochs. Defaults to True.
+        verbose (bool): Whether to print the value for each update.
+            Defaults to False.
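+
+    Examples:
+        A minimal usage sketch; ``param_name`` picks the optimizer field to
+        schedule (``'lr'`` here, purely for illustration):
+
+        >>> import torch
+        >>> from torch.optim import SGD
+        >>> optimizer = SGD(torch.nn.Linear(2, 2).parameters(), lr=0.05)
+        >>> scheduler = CosineRestartParamScheduler(
+        ...     optimizer,
+        ...     param_name='lr',
+        ...     periods=[4, 5],
+        ...     restart_weights=[1, 0.5],
+        ...     eta_min=0)
+        >>> for _ in range(9):
+        ...     scheduler.step()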
+    """
+
+    def __init__(self,
+                 optimizer: Union[Optimizer, OptimWrapper],
+                 param_name: str,
+                 periods: List[int],
+                 restart_weights: Sequence[float] = (1, ),
+                 eta_min: Optional[float] = None,
+                 eta_min_ratio: Optional[float] = None,
+                 begin: int = 0,
+                 end: int = INF,
+                 last_step: int = -1,
+                 by_epoch: bool = True,
+                 verbose: bool = False):
+        assert (eta_min is None) ^ (eta_min_ratio is None)
+        self.periods = periods
+        self.eta_min = eta_min
+        self.eta_min_ratio = eta_min_ratio
+        self.restart_weights = restart_weights
+        assert (len(self.periods) == len(self.restart_weights)
+                ), 'periods and restart_weights should have the same length.'
+        self.cumulative_periods = [
+            sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
+        ]
+
+        super().__init__(
+            optimizer,
+            param_name=param_name,
+            begin=begin,
+            end=end,
+            last_step=last_step,
+            by_epoch=by_epoch,
+            verbose=verbose)
+
+    @classmethod
+    def build_iter_from_epoch(cls,
+                              *args,
+                              periods,
+                              begin=0,
+                              end=INF,
+                              by_epoch=True,
+                              epoch_length=None,
+                              **kwargs):
+        """Build an iter-based instance of this scheduler from an epoch-based
+        config."""
+        assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
+                         'be converted to iter-based.'
+        assert epoch_length is not None and epoch_length > 0, \
+            f'`epoch_length` must be a positive integer, ' \
+            f'but got {epoch_length}.'
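+        # convert epoch-based quantities into iterations, e.g. with
+        # epoch_length=100, periods=[4, 5] becomes periods=[400, 500]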
+        periods = [p * epoch_length for p in periods]
+        by_epoch = False
+        begin = int(begin * epoch_length)
+        if end != INF:
+            end = int(end * epoch_length)
+        return cls(
+            *args,
+            periods=periods,
+            begin=begin,
+            end=end,
+            by_epoch=by_epoch,
+            **kwargs)
+
+    def _get_value(self):
+        """Compute value using chainable form of the scheduler."""
+        idx = self.get_position_from_periods(self.last_step,
+                                             self.cumulative_periods)
+        # if the current step is beyond all periods, keep the current values
+        if idx is None:
+            return [
+                group[self.param_name] for group in self.optimizer.param_groups
+            ]
+        current_weight = self.restart_weights[idx]
+        nearest_restart = 0 if idx == 0 else self.cumulative_periods[idx - 1]
+        current_periods = self.periods[idx]
+        step = self.last_step - nearest_restart
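+        # within the current period the value follows the chainable cosine
+        # annealing recurrence, restarting from
+        # eta_max = base_value * restart_weight at step == 0:
+        #   v_t = (1 + cos(pi * t / T)) / (1 + cos(pi * (t - 1) / T))
+        #         * (v_{t-1} - eta_min) + eta_min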
+        values = []
+        for base_value, group in zip(self.base_values,
+                                     self.optimizer.param_groups):
+            eta_max = base_value * current_weight
+            if self.eta_min_ratio is None:
+                eta_min = self.eta_min * (1 - current_weight)
+            else:
+                eta_min = base_value * self.eta_min_ratio * (1 -
+                                                             current_weight)
+            if step == 0:
+                values.append(eta_max)
+
+            elif (step - 1 - current_periods) % (2 * current_periods) == 0:
+                values.append(group[self.param_name] + (eta_max - eta_min) *
+                              (1 - math.cos(math.pi / current_periods)) / 2)
+            else:
+                values.append(
+                    (1 + math.cos(math.pi * step / current_periods)) /
+                    (1 + math.cos(math.pi * (step - 1) / current_periods)) *
+                    (group[self.param_name] - eta_min) + eta_min)
+
+        return values
+
+    @staticmethod
+    def get_position_from_periods(
+            iteration: int, cumulative_periods: List[int]) -> Optional[int]:
+        """Get the position from a period list.
+
+        It returns the index of the first cumulative period that is strictly
+        greater than the given iteration.
+        For example, the cumulative_periods = [100, 200, 300, 400],
+        if iteration == 50, return 0;
+        if iteration == 210, return 2;
+        if iteration == 300, return 3.
+
+        Args:
+            iteration (int): Current iteration.
+            cumulative_periods (list[int]): Cumulative period list.
+
+        Returns:
+            Optional[int]: The index of the period that the iteration falls
+            into. Returns None if the iteration is beyond all periods.
+        """
+        for i, period in enumerate(cumulative_periods):
+            if iteration < period:
+                return i
+        return None
diff --git a/tests/test_optim/test_scheduler/test_lr_scheduler.py b/tests/test_optim/test_scheduler/test_lr_scheduler.py
index 87f83479537cb16fb3b2f5d2b10c4a9c3a97a539..bd537b8680e2d0e4a8f72ddfc43f680cd5dbc0e1 100644
--- a/tests/test_optim/test_scheduler/test_lr_scheduler.py
+++ b/tests/test_optim/test_scheduler/test_lr_scheduler.py
@@ -7,8 +7,8 @@ import torch.nn.functional as F
 import torch.optim as optim
 
 from mmengine.optim.scheduler import (ConstantLR, CosineAnnealingLR,
-                                      ExponentialLR, LinearLR, MultiStepLR,
-                                      OneCycleLR, PolyLR, StepLR,
+                                      CosineRestartLR, ExponentialLR, LinearLR,
+                                      MultiStepLR, OneCycleLR, PolyLR, StepLR,
                                       _ParamScheduler)
 from mmengine.testing import assert_allclose
 
@@ -333,6 +333,34 @@ class TestLRScheduler(TestCase):
             self.optimizer, power=power, eta_min=min_lr, end=iters + 1)
         self._test_scheduler_value(scheduler, targets, epochs=10)
 
+    def test_cosine_restart_scheduler(self):
+        with self.assertRaises(AssertionError):
+            CosineRestartLR(
+                self.optimizer,
+                periods=[4, 5],
+                restart_weights=[1, 0.5],
+                eta_min=0,
+                eta_min_ratio=0.1)
+        with self.assertRaises(AssertionError):
+            CosineRestartLR(
+                self.optimizer,
+                periods=[4, 5],
+                restart_weights=[1, 0.5, 0.0],
+                eta_min=0)
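+        # expected values: cosine decay from lr=0.05 over a 4-epoch period,
+        # then a restart from 0.5 * 0.05 = 0.025 decayed over a 5-epoch
+        # period; the last value is held once all periods are finished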
+        single_targets = [
+            0.05, 0.0426776, 0.025, 0.00732233, 0.025, 0.022612712, 0.01636271,
+            0.0086372, 0.0023872, 0.0023872
+        ]
+        targets = [
+            single_targets, [t * self.layer2_mult for t in single_targets]
+        ]
+        scheduler = CosineRestartLR(
+            self.optimizer,
+            periods=[4, 5],
+            restart_weights=[1, 0.5],
+            eta_min=0)
+        self._test_scheduler_value(scheduler, targets, epochs=10)
+
     def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
         scheduler = construct()
         for _ in range(epochs):
@@ -387,6 +415,20 @@ class TestLRScheduler(TestCase):
             lambda: PolyLR(self.optimizer, power=0.8, eta_min=0.002),
             epochs=10)
 
+    def test_cosine_restart_scheduler_state_dict(self):
+        self._check_scheduler_state_dict(
+            lambda: CosineRestartLR(
+                self.optimizer,
+                periods=[4, 5],
+                restart_weights=[1, 0.5],
+                eta_min=0),
+            lambda: CosineRestartLR(
+                self.optimizer,
+                periods=[4, 6],
+                restart_weights=[1, 0.5],
+                eta_min=0),
+            epochs=10)
+
     def test_step_scheduler_convert_iterbased(self):
         # invalid epoch_length
         with self.assertRaises(AssertionError):
diff --git a/tests/test_optim/test_scheduler/test_momentum_scheduler.py b/tests/test_optim/test_scheduler/test_momentum_scheduler.py
index 7014205e0b0238d4fe3e7d7ce1e6b1a4669bf636..d98f2fe424a2f915eb9c200b6b0d423e67bab15c 100644
--- a/tests/test_optim/test_scheduler/test_momentum_scheduler.py
+++ b/tests/test_optim/test_scheduler/test_momentum_scheduler.py
@@ -8,6 +8,7 @@ import torch.optim as optim
 
 from mmengine.optim.scheduler import (ConstantMomentum,
                                       CosineAnnealingMomentum,
+                                      CosineRestartMomentum,
                                       ExponentialMomentum, LinearMomentum,
                                       MultiStepMomentum, PolyMomentum,
                                       StepMomentum, _ParamScheduler)
@@ -399,6 +400,43 @@ class TestMomentumScheduler(TestCase):
         self._test_scheduler_value(
             self.optimizer_with_betas, scheduler, targets, epochs=10)
 
+    def test_cosine_restart_scheduler(self):
+        with self.assertRaises(AssertionError):
+            CosineRestartMomentum(
+                self.optimizer,
+                periods=[4, 5],
+                restart_weights=[1, 0.5],
+                eta_min=0,
+                eta_min_ratio=0.1)
+        with self.assertRaises(AssertionError):
+            CosineRestartMomentum(
+                self.optimizer,
+                periods=[4, 5],
+                restart_weights=[1, 0.5, 0.0],
+                eta_min=0)
+        single_targets = [
+            0.05, 0.0426776, 0.025, 0.00732233, 0.025, 0.022612712, 0.01636271,
+            0.0086372, 0.0023872, 0.0023872
+        ]
+        targets = [
+            single_targets, [t * self.layer2_mult for t in single_targets]
+        ]
+        scheduler = CosineRestartMomentum(
+            self.optimizer,
+            periods=[4, 5],
+            restart_weights=[1, 0.5],
+            eta_min=0)
+        self._test_scheduler_value(
+            self.optimizer, scheduler, targets, epochs=10)
+
+        scheduler = CosineRestartMomentum(
+            self.optimizer_with_betas,
+            periods=[4, 5],
+            restart_weights=[1, 0.5],
+            eta_min=0)
+        self._test_scheduler_value(
+            self.optimizer_with_betas, scheduler, targets, epochs=10)
+
     def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
         scheduler = construct()
         for _ in range(epochs):
@@ -454,6 +492,20 @@ class TestMomentumScheduler(TestCase):
             lambda: PolyMomentum(self.optimizer, power=0.8, eta_min=0.002),
             epochs=10)
 
+    def test_cosine_restart_scheduler_state_dict(self):
+        self._check_scheduler_state_dict(
+            lambda: CosineRestartMomentum(
+                self.optimizer,
+                periods=[4, 5],
+                restart_weights=[1, 0.5],
+                eta_min=0),
+            lambda: CosineRestartMomentum(
+                self.optimizer,
+                periods=[4, 6],
+                restart_weights=[1, 0.5],
+                eta_min=0),
+            epochs=10)
+
     def test_multi_scheduler_without_overlap_linear_multi_step(self):
         # use Linear in the first 5 epochs and then use MultiStep
         epochs = 12
diff --git a/tests/test_optim/test_scheduler/test_param_scheduler.py b/tests/test_optim/test_scheduler/test_param_scheduler.py
index 753bacb1281aebcdcd32dc5ded4dd7b8ba219500..e27f0dd7b1c77e0e758666d7024d739787ab8f38 100644
--- a/tests/test_optim/test_scheduler/test_param_scheduler.py
+++ b/tests/test_optim/test_scheduler/test_param_scheduler.py
@@ -12,12 +12,13 @@ from mmengine.optim import OptimWrapper
 # yapf: disable
 from mmengine.optim.scheduler import (ConstantParamScheduler,
                                       CosineAnnealingParamScheduler,
+                                      CosineRestartParamScheduler,
                                       ExponentialParamScheduler,
                                       LinearParamScheduler,
                                       MultiStepParamScheduler,
+                                      OneCycleParamScheduler,
                                       PolyParamScheduler, StepParamScheduler,
                                       _ParamScheduler)
-from mmengine.optim.scheduler.param_scheduler import OneCycleParamScheduler
 # yapf: enable
 from mmengine.testing import assert_allclose
 
@@ -406,6 +407,37 @@ class TestParameterScheduler(TestCase):
             end=iters + 1)
         self._test_scheduler_value(scheduler, targets, epochs=10)
 
+    def test_cosine_restart_scheduler(self):
+        with self.assertRaises(AssertionError):
+            CosineRestartParamScheduler(
+                self.optimizer,
+                param_name='lr',
+                periods=[4, 5],
+                restart_weights=[1, 0.5],
+                eta_min=0,
+                eta_min_ratio=0.1)
+        with self.assertRaises(AssertionError):
+            CosineRestartParamScheduler(
+                self.optimizer,
+                param_name='lr',
+                periods=[4, 5],
+                restart_weights=[1, 0.5, 0.0],
+                eta_min=0)
+        single_targets = [
+            0.05, 0.0426776, 0.025, 0.00732233, 0.025, 0.022612712, 0.01636271,
+            0.0086372, 0.0023872, 0.0023872
+        ]
+        targets = [
+            single_targets, [t * self.layer2_mult for t in single_targets]
+        ]
+        scheduler = CosineRestartParamScheduler(
+            self.optimizer,
+            param_name='lr',
+            periods=[4, 5],
+            restart_weights=[1, 0.5],
+            eta_min=0)
+        self._test_scheduler_value(scheduler, targets, epochs=10)
+
     def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
         scheduler = construct()
         for _ in range(epochs):
@@ -483,6 +515,22 @@ class TestParameterScheduler(TestCase):
                 self.optimizer, param_name='lr', power=0.8, eta_min=0.002),
             epochs=10)
 
+    def test_cosine_restart_scheduler_state_dict(self):
+        self._check_scheduler_state_dict(
+            lambda: CosineRestartParamScheduler(
+                self.optimizer,
+                param_name='lr',
+                periods=[4, 5],
+                restart_weights=[1, 0.5],
+                eta_min=0),
+            lambda: CosineRestartParamScheduler(
+                self.optimizer,
+                param_name='lr',
+                periods=[4, 6],
+                restart_weights=[1, 0.5],
+                eta_min=0),
+            epochs=10)
+
     def test_step_scheduler_convert_iterbased(self):
         # invalid epoch_length
         with self.assertRaises(AssertionError):