diff --git a/mmengine/hooks/ema_hook.py b/mmengine/hooks/ema_hook.py
index d4834b7c90224842b8f79ae1286eb9b8b84a92ee..d0d5e3acb61b0bd69b72d167975eb499a9a70860 100644
--- a/mmengine/hooks/ema_hook.py
+++ b/mmengine/hooks/ema_hook.py
@@ -71,6 +71,12 @@ class EMAHook(Hook):
         self.ema_model = MODELS.build(
             self.ema_cfg, default_args=dict(model=self.src_model))
 
+    def before_train(self, runner) -> None:
+        """Check the begin_epoch/iter is smaller than max_epochs/iters.
+
+        Args:
+            runner (Runner): The runner of the training process.
+        """
         if self.enabled_by_epoch:
             assert self.begin_epoch <= runner.max_epochs, (
                 'self.begin_epoch should be smaller than runner.max_epochs: '
@@ -96,6 +102,11 @@ class EMAHook(Hook):
         """
         if self._ema_started(runner):
             self.ema_model.update_parameters(self.src_model)
+        else:
+            ema_params = self.ema_model.module.state_dict()
+            src_params = self.src_model.state_dict()
+            for k, p in ema_params.items():
+                p.data.copy_(src_params[k].data)
 
     def before_val_epoch(self, runner) -> None:
         """We load parameter values from ema model to source model before
@@ -104,8 +115,7 @@ class EMAHook(Hook):
         Args:
             runner (Runner): The runner of the training process.
         """
-        if self._ema_started(runner):
-            self._swap_ema_parameters()
+        self._swap_ema_parameters()
 
     def after_val_epoch(self,
                         runner,
@@ -118,8 +128,7 @@ class EMAHook(Hook):
                 metrics on validation dataset. The keys are the names of the
                 metrics, and the values are corresponding results.
         """
-        if self._ema_started(runner):
-            self._swap_ema_parameters()
+        self._swap_ema_parameters()
 
     def before_test_epoch(self, runner) -> None:
         """We load parameter values from ema model to source model before test.
@@ -127,8 +136,7 @@ class EMAHook(Hook):
         Args:
             runner (Runner): The runner of the training process.
         """
-        if self._ema_started(runner):
-            self._swap_ema_parameters()
+        self._swap_ema_parameters()
 
     def after_test_epoch(self,
                          runner,
@@ -141,8 +149,7 @@ class EMAHook(Hook):
                 metrics on test dataset. The keys are the names of the
                 metrics, and the values are corresponding results.
         """
-        if self._ema_started(runner):
-            self._swap_ema_parameters()
+        self._swap_ema_parameters()
 
     def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
         """Save ema parameters to checkpoint.
@@ -150,14 +157,13 @@ class EMAHook(Hook):
         Args:
             runner (Runner): The runner of the testing process.
        """
-        if self._ema_started(runner):
-            checkpoint['ema_state_dict'] = self.ema_model.state_dict()
-            # Save ema parameters to the source model's state dict so that we
-            # can directly load the averaged model weights for deployment.
-            # Swapping the state_dict key-values instead of swapping model
-            # parameters because the state_dict is a shallow copy of model
-            # parameters.
-            self._swap_ema_state_dict(checkpoint)
+        checkpoint['ema_state_dict'] = self.ema_model.state_dict()
+        # Save ema parameters to the source model's state dict so that we
+        # can directly load the averaged model weights for deployment.
+        # Swapping the state_dict key-values instead of swapping model
+        # parameters because the state_dict is a shallow copy of model
+        # parameters.
+        self._swap_ema_state_dict(checkpoint)
 
     def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
         """Resume ema parameters from checkpoint.
@@ -165,23 +171,22 @@ class EMAHook(Hook):
         Args:
             runner (Runner): The runner of the testing process.
         """
-        if self._ema_started(runner):
-            if 'ema_state_dict' in checkpoint:
-                # The original model parameters are actually saved in ema
-                # field swap the weights back to resume ema state.
-                self._swap_ema_state_dict(checkpoint)
-                self.ema_model.load_state_dict(
-                    checkpoint['ema_state_dict'], strict=self.strict_load)
-
-            # Support load checkpoint without ema state dict.
-            else:
-                print_log(
-                    'There is no `ema_state_dict` in checkpoint. '
-                    '`EMAHook` will make a copy of `state_dict` as the '
-                    'initial `ema_state_dict`', 'current', logging.WARNING)
-                self.ema_model.module.load_state_dict(
-                    copy.deepcopy(checkpoint['state_dict']),
-                    strict=self.strict_load)
+        if 'ema_state_dict' in checkpoint:
+            # The original model parameters are actually saved in ema
+            # field swap the weights back to resume ema state.
+            self._swap_ema_state_dict(checkpoint)
+            self.ema_model.load_state_dict(
+                checkpoint['ema_state_dict'], strict=self.strict_load)
+
+        # Support load checkpoint without ema state dict.
+        else:
+            print_log(
+                'There is no `ema_state_dict` in checkpoint. '
+                '`EMAHook` will make a copy of `state_dict` as the '
+                'initial `ema_state_dict`', 'current', logging.WARNING)
+            self.ema_model.module.load_state_dict(
+                copy.deepcopy(checkpoint['state_dict']),
+                strict=self.strict_load)
 
     def _swap_ema_parameters(self) -> None:
         """Swap the parameter of model with ema_model."""
diff --git a/mmengine/model/averaged_model.py b/mmengine/model/averaged_model.py
index 8a47e8e07157c5a6f53f24c4f1507504f4e0aaa9..be1abc30bda2352f5aceee00f3ccc11e1cd5d2c2 100644
--- a/mmengine/model/averaged_model.py
+++ b/mmengine/model/averaged_model.py
@@ -106,6 +111,11 @@ class BaseAveragedModel(nn.Module):
                     self.avg_func(p_avg.data,
                                   src_parameters[k].data.to(device),
                                   self.steps)
+        if not self.update_buffers:
+            # If not update the buffers,
+            # keep the buffers in sync with the source model.
+            for b_avg, b_src in zip(self.module.buffers(), model.buffers()):
+                b_avg.data.copy_(b_src.data.to(b_avg.device))
         self.steps += 1
 
 
diff --git a/tests/test_hooks/test_ema_hook.py b/tests/test_hooks/test_ema_hook.py
index 4fcced7dbca94ecd6d74c165e3fb2e27376929de..4b7e7d7bca2928e1f9dcc04518e2d2ed003497d7 100644
--- a/tests/test_hooks/test_ema_hook.py
+++ b/tests/test_hooks/test_ema_hook.py
@@ -14,6 +14,7 @@ from mmengine.model import BaseModel, ExponentialMovingAverage
 from mmengine.optim import OptimWrapper
 from mmengine.registry import DATASETS, MODEL_WRAPPERS
 from mmengine.runner import Runner
+from mmengine.testing import assert_allclose
 
 
 class ToyModel(nn.Module):
@@ -225,9 +226,13 @@ class TestEMAHook(TestCase):
             custom_hooks=[dict(type='EMAHook', begin_epoch=5)],
             experiment_name='test6')
         runner.train()
-        state_dict = torch.load(osp.join(self.temp_dir.name, 'epoch_4.pth'))
-        self.assertNotIn('ema_state_dict', state_dict)
-        state_dict = torch.load(osp.join(self.temp_dir.name, 'epoch_5.pth'))
+        state_dict = torch.load(
+            osp.join(self.temp_dir.name, 'epoch_4.pth'), map_location='cpu')
+        self.assertIn('ema_state_dict', state_dict)
+        for k, v in state_dict['state_dict'].items():
+            assert_allclose(v, state_dict['ema_state_dict']['module.' + k])
+        state_dict = torch.load(
+            osp.join(self.temp_dir.name, 'epoch_5.pth'), map_location='cpu')
         self.assertIn('ema_state_dict', state_dict)
 
         # Test enable ema at 5 iterations.
@@ -255,7 +260,11 @@ class TestEMAHook(TestCase):
             custom_hooks=[dict(type='EMAHook', begin_iter=5)],
             experiment_name='test7')
         runner.train()
-        state_dict = torch.load(osp.join(self.temp_dir.name, 'iter_4.pth'))
-        self.assertNotIn('ema_state_dict', state_dict)
-        state_dict = torch.load(osp.join(self.temp_dir.name, 'iter_5.pth'))
+        state_dict = torch.load(
+            osp.join(self.temp_dir.name, 'iter_4.pth'), map_location='cpu')
+        self.assertIn('ema_state_dict', state_dict)
+        for k, v in state_dict['state_dict'].items():
+            assert_allclose(v, state_dict['ema_state_dict']['module.' + k])
+        state_dict = torch.load(
+            osp.join(self.temp_dir.name, 'iter_5.pth'), map_location='cpu')
         self.assertIn('ema_state_dict', state_dict)