diff --git a/mmengine/hooks/ema_hook.py b/mmengine/hooks/ema_hook.py
index 92afd96b08d09d33cd20a4bdc0ff719d51049510..4394744e8895bbb5f4315306bfdd8e0bdb1f078e 100644
--- a/mmengine/hooks/ema_hook.py
+++ b/mmengine/hooks/ema_hook.py
@@ -19,17 +19,44 @@ class EMAHook(Hook):
         - EMAHook takes priority over CheckpointHook.
         - The original model parameters are actually saved in ema field after
           train.
+        - ``begin_iter`` and ``begin_epoch`` cannot be set at the same time.
 
     Args:
         ema_type (str): The type of EMA strategy to use. You can find the
-            supported strategies in ``mmengine.model.averaged_model``.
-            Defaults to 'ExponentialMovingAverage'
+            supported strategies in :mod:`mmengine.model.averaged_model`.
+            Defaults to 'ExponentialMovingAverage'.
+        strict_load (bool): Whether to strictly enforce that the keys of
+            ``state_dict`` in checkpoint match the keys returned by
+            ``self.module.state_dict``. Defaults to True.
+        begin_iter (int): The number of iteration to enable ``EMAHook``.
+            Defaults to 0.
+        begin_epoch (int): The number of epoch to enable ``EMAHook``. Defaults
+            to 0.
+        **kwargs: Keyword arguments passed to subclasses of
+            :obj:`BaseAveragedModel`
     """
 
     priority = 'NORMAL'
 
-    def __init__(self, ema_type: str = 'ExponentialMovingAverage', **kwargs):
+    def __init__(self,
+                 ema_type: str = 'ExponentialMovingAverage',
+                 strict_load: bool = True,
+                 begin_iter: int = 0,
+                 begin_epoch: int = 0,
+                 **kwargs):
+        self.strict_load = strict_load
         self.ema_cfg = dict(type=ema_type, **kwargs)
+        assert not (begin_iter != 0 and begin_epoch != 0), (
+            '`begin_iter` and `begin_epoch` should not be both set.')
+        assert begin_iter >= 0, (
+            f'begin_iter must larger than 0, but got begin: {begin_iter}')
+        assert begin_epoch >= 0, (
+            f'begin_epoch must larger than 0, but got begin: {begin_epoch}')
+        self.begin_iter = begin_iter
+        self.begin_epoch = begin_epoch
+        # If `begin_epoch` and `begin_iter` are not set, `EMAHook` will be
+        # enabled at 0 iteration.
+        self.enabled_by_epoch = self.begin_epoch > 0
 
     def before_run(self, runner) -> None:
         """Create an ema copy of the model."""
@@ -40,64 +67,81 @@ class EMAHook(Hook):
         self.ema_model = MODELS.build(
             self.ema_cfg, default_args=dict(model=self.src_model))
 
+        if self.enabled_by_epoch:
+            assert self.begin_epoch <= runner.max_epochs, (
+                'self.begin_epoch should be smaller than runner.max_epochs: '
+                f'{runner.max_epochs}, but got begin: {self.begin_epoch}')
+        else:
+            assert self.begin_iter <= runner.max_iters, (
+                'self.begin_iter should be smaller than runner.max_iters: '
+                f'{runner.max_iters}, but got begin: {self.begin_iter}')
+
     def after_train_iter(self,
                          runner,
                          batch_idx: int,
                          data_batch: DATA_BATCH = None,
                          outputs: Optional[dict] = None) -> None:
         """Update ema parameter."""
-        self.ema_model.update_parameters(self.src_model)
+        if self._ema_started(runner):
+            self.ema_model.update_parameters(self.src_model)
 
     def before_val_epoch(self, runner) -> None:
         """We load parameter values from ema model to source model before
         validation."""
-        self._swap_ema_parameters()
+        if self._ema_started(runner):
+            self._swap_ema_parameters()
 
     def after_val_epoch(self,
                         runner,
                         metrics: Optional[Dict[str, float]] = None) -> None:
         """We recover source model's parameter from ema model after
         validation."""
-        self._swap_ema_parameters()
+        if self._ema_started(runner):
+            self._swap_ema_parameters()
 
     def before_test_epoch(self, runner) -> None:
         """We load parameter values from ema model to source model before
         test."""
-        self._swap_ema_parameters()
+        if self._ema_started(runner):
+            self._swap_ema_parameters()
 
     def after_test_epoch(self,
                          runner,
                          metrics: Optional[Dict[str, float]] = None) -> None:
         """We recover source model's parameter from ema model after test."""
-        self._swap_ema_parameters()
+        if self._ema_started(runner):
+            self._swap_ema_parameters()
 
     def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
         """Save ema parameters to checkpoint."""
-        checkpoint['ema_state_dict'] = self.ema_model.state_dict()
-        # Save ema parameters to the source model's state dict so that we can
-        # directly load the averaged model weights for deployment.
-        # Swapping the state_dict key-values instead of swapping model
-        # parameters because the state_dict is a shallow copy of model
-        # parameters.
-        self._swap_ema_state_dict(checkpoint)
+        if self._ema_started(runner):
+            checkpoint['ema_state_dict'] = self.ema_model.state_dict()
+            # Save ema parameters to the source model's state dict so that we
+            # can directly load the averaged model weights for deployment.
+            # Swapping the state_dict key-values instead of swapping model
+            # parameters because the state_dict is a shallow copy of model
+            # parameters.
+            self._swap_ema_state_dict(checkpoint)
 
     def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
         """Resume ema parameters from checkpoint."""
-
-        if 'ema_state_dict' in checkpoint:
-            # The original model parameters are actually saved in ema field.
-            # swap the weights back to resume ema state.
-            self._swap_ema_state_dict(checkpoint)
-            self.ema_model.load_state_dict(checkpoint['ema_state_dict'])
-
-        # Support load checkpoint without ema state dict.
-        else:
-            print_log(
-                'There is no `ema_state_dict` in checkpoint. '
-                '`EMAHook` will make a copy of `state_dict` as the '
-                'initial `ema_state_dict`', 'current', logging.WARNING)
-            self.ema_model.module.load_state_dict(
-                copy.deepcopy(checkpoint['state_dict']))
+        if self._ema_started(runner):
+            if 'ema_state_dict' in checkpoint:
+                # The original model parameters are actually saved in ema
+                # field swap the weights back to resume ema state.
+                self._swap_ema_state_dict(checkpoint)
+                self.ema_model.load_state_dict(
+                    checkpoint['ema_state_dict'], strict=self.strict_load)
+
+            # Support load checkpoint without ema state dict.
+            else:
+                print_log(
+                    'There is no `ema_state_dict` in checkpoint. '
+                    '`EMAHook` will make a copy of `state_dict` as the '
+                    'initial `ema_state_dict`', 'current', logging.WARNING)
+                self.ema_model.module.load_state_dict(
+                    copy.deepcopy(checkpoint['state_dict']),
+                    strict=self.strict_load)
 
     def _swap_ema_parameters(self) -> None:
         """Swap the parameter of model with ema_model."""
@@ -124,3 +168,21 @@ class EMAHook(Hook):
                 tmp = ema_state[k]
                 ema_state[k] = model_state[k[7:]]
                 model_state[k[7:]] = tmp
+
+    def _ema_started(self, runner) -> bool:
+        """Whether ``EMAHook`` has been initialized at current iteration or
+        epoch.
+
+        :attr:`ema_model` will be initialized when ``runner.iter`` or
+        ``runner.epoch`` is greater than ``self.begin`` for the first time.
+
+        Args:
+            runner (Runner): Runner of the training, validation process.
+
+        Returns:
+            bool: Whether ``EMAHook`` has been initialized.
+        """
+        if self.enabled_by_epoch:
+            return runner.epoch + 1 >= self.begin_epoch
+        else:
+            return runner.iter + 1 >= self.begin_iter
diff --git a/tests/test_hook/test_ema_hook.py b/tests/test_hook/test_ema_hook.py
index 39e8ae52f99a31d35539900d41618c99ec968456..571125bc4e8bd1a08394ce5814e07874528baa53 100644
--- a/tests/test_hook/test_ema_hook.py
+++ b/tests/test_hook/test_ema_hook.py
@@ -43,6 +43,16 @@ class ToyModel1(BaseModel, ToyModel):
         return super(BaseModel, self).forward(*args, **kwargs)
 
 
+class ToyModel2(BaseModel, ToyModel):
+
+    def __init__(self):
+        super().__init__()
+        self.linear1 = nn.Linear(2, 1)
+
+    def forward(self, *args, **kwargs):
+        return super(BaseModel, self).forward(*args, **kwargs)
+
+
 @DATASETS.register_module()
 class DummyDataset(Dataset):
     METAINFO = dict()  # type: ignore
@@ -171,3 +181,79 @@ class TestEMAHook(TestCase):
             custom_hooks=[dict(type='EMAHook')],
             experiment_name='test4')
         runner.test()
+
+        # Test does not load ckpt strict_loadly.
+        # Test load checkpoint without ema_state_dict
+        runner = Runner(
+            model=ToyModel2(),
+            test_dataloader=dict(
+                dataset=dict(type='DummyDataset'),
+                sampler=dict(type='DefaultSampler', shuffle=True),
+                batch_size=3,
+                num_workers=0),
+            test_evaluator=evaluator,
+            test_cfg=dict(),
+            work_dir=self.temp_dir.name,
+            load_from=osp.join(self.temp_dir.name, 'epoch_2.pth'),
+            default_hooks=dict(logger=None),
+            custom_hooks=[dict(type='EMAHook', strict_load=False)],
+            experiment_name='test5')
+        runner.test()
+
+        # Test enable ema at 5 epochs.
+        runner = Runner(
+            model=model,
+            train_dataloader=dict(
+                dataset=dict(type='DummyDataset'),
+                sampler=dict(type='DefaultSampler', shuffle=True),
+                batch_size=3,
+                num_workers=0),
+            val_dataloader=dict(
+                dataset=dict(type='DummyDataset'),
+                sampler=dict(type='DefaultSampler', shuffle=False),
+                batch_size=3,
+                num_workers=0),
+            val_evaluator=evaluator,
+            work_dir=self.temp_dir.name,
+            optim_wrapper=OptimWrapper(
+                torch.optim.Adam(ToyModel().parameters())),
+            train_cfg=dict(by_epoch=True, max_epochs=10, val_interval=1),
+            val_cfg=dict(),
+            default_hooks=dict(logger=None),
+            custom_hooks=[dict(type='EMAHook', begin_epoch=5)],
+            experiment_name='test6')
+        runner.train()
+        state_dict = torch.load(osp.join(self.temp_dir.name, 'epoch_4.pth'))
+        self.assertNotIn('ema_state_dict', state_dict)
+        state_dict = torch.load(osp.join(self.temp_dir.name, 'epoch_5.pth'))
+        self.assertIn('ema_state_dict', state_dict)
+
+        # Test enable ema at 5 iterations.
+        runner = Runner(
+            model=model,
+            train_dataloader=dict(
+                dataset=dict(type='DummyDataset'),
+                sampler=dict(type='DefaultSampler', shuffle=True),
+                batch_size=3,
+                num_workers=0),
+            val_dataloader=dict(
+                dataset=dict(type='DummyDataset'),
+                sampler=dict(type='DefaultSampler', shuffle=False),
+                batch_size=3,
+                num_workers=0),
+            val_evaluator=evaluator,
+            work_dir=self.temp_dir.name,
+            optim_wrapper=OptimWrapper(
+                torch.optim.Adam(ToyModel().parameters())),
+            train_cfg=dict(by_epoch=False, max_iters=10, val_interval=1),
+            val_cfg=dict(),
+            default_hooks=dict(
+                checkpoint=dict(
+                    type='CheckpointHook', interval=1, by_epoch=False)),
+            custom_hooks=[dict(type='EMAHook', begin_iter=5)],
+            experiment_name='test7')
+        runner.train()
+        state_dict = torch.load(osp.join(self.temp_dir.name, 'iter_4.pth'))
+        self.assertNotIn('ema_state_dict', state_dict)
+        state_dict = torch.load(osp.join(self.temp_dir.name, 'iter_5.pth'))
+        self.assertIn('ema_state_dict', state_dict)