diff --git a/mmengine/hooks/ema_hook.py b/mmengine/hooks/ema_hook.py
index 8daa58552a4f76eb11343e8d2d4e16d7efbf834b..2686a89ebd47748da9564fcd9869a298243581c8 100644
--- a/mmengine/hooks/ema_hook.py
+++ b/mmengine/hooks/ema_hook.py
@@ -70,18 +70,20 @@ class EMAHook(Hook):
 
     def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
         """Save ema parameters to checkpoint."""
-        # save ema parameters to the source model's state dict so that we can
-        # directly load the averaged model weights for deployment.
-        self._swap_ema_parameters()
         checkpoint['ema_state_dict'] = self.ema_model.state_dict()
-        self._swap_ema_parameters()
+        # Save ema parameters to the source model's state dict so that we can
+        # directly load the averaged model weights for deployment.
+        # Swapping the state_dict key-values instead of swapping model
+        # parameters because the state_dict is a shallow copy of model
+        # parameters.
+        self._swap_ema_state_dict(checkpoint)
 
     def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
         """Resume ema parameters from checkpoint."""
-        self.ema_model.load_state_dict(checkpoint['ema_state_dict'])
         # The original model parameters are actually saved in ema field.
         # swap the weights back to resume ema state.
-        self._swap_ema_parameters()
+        self._swap_ema_state_dict(checkpoint)
+        self.ema_model.load_state_dict(checkpoint['ema_state_dict'])
 
     def _swap_ema_parameters(self) -> None:
         """Swap the parameter of model with ema_model."""
@@ -98,3 +100,13 @@ class EMAHook(Hook):
             tmp = p_avg.data.clone()
             p_avg.data.copy_(p_src.data)
             p_src.data.copy_(tmp)
+
+    def _swap_ema_state_dict(self, checkpoint):
+        """Swap the state dict values of model with ema_model."""
+        model_state = checkpoint['state_dict']
+        ema_state = checkpoint['ema_state_dict']
+        for k in ema_state:
+            if k.startswith('module.'):
+                tmp = ema_state[k]
+                ema_state[k] = model_state[k[7:]]
+                model_state[k[7:]] = tmp
diff --git a/tests/test_model/test_averaged_model.py b/tests/test_model/test_averaged_model.py
index f4ed11861b6821af7c1a2e4d1389f95e67ac1d9b..9afef4e607d5e4e2fc8950878eaa9a5e7a687ad4 100644
--- a/tests/test_model/test_averaged_model.py
+++ b/tests/test_model/test_averaged_model.py
@@ -187,7 +187,7 @@ class TestAveragedModel(TestCase):
                 if param.size() != torch.Size([])
             ]
             for p, p_avg in zip(params, averaged_params):
-                p.detach().add_(torch.randn_like(p))
+                p.detach().add_(torch.randn_like(p))
                 if i == 0:
                     updated_averaged_params.append(p.clone())
                 else:
@@ -234,7 +234,7 @@ class TestAveragedModel(TestCase):
                 if param.size() != torch.Size([])
             ]
             for p, p_avg in zip(params, averaged_params):
-                p.detach().add_(torch.randn_like(p))
+                p.detach().add_(torch.randn_like(p))
                 if i == 0:
                     updated_averaged_params.append(p.clone())
                 elif i % interval == 0: