From 172b9ded4ade88bec81ee9efdc9ad1587abf482f Mon Sep 17 00:00:00 2001
From: RangiLyu <lyuchqi@gmail.com>
Date: Mon, 30 May 2022 16:51:06 +0800
Subject: [PATCH] [Fix] Fix ema state dict swapping in EMAHook and the torch
 1.5 unit tests. (#266)

* [Fix] Fix ema state dict swapping in EMAHook.

* Fix the PyTorch 1.5 unit tests.

* Add more comments.
---
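
Reviewer note (not part of the commit): with the default keep_vars=False,
state_dict() detaches each tensor, so the returned dict holds new tensor
objects that still share storage with the live parameters -- a shallow copy.
Because checkpoint['state_dict'] aliases the live parameters, the old
swap-copy-swap approach was undone by the final swap, which wrote through
those aliases; rebinding the checkpoint's dict values instead never touches
parameter storage. A minimal standalone sketch of the aliasing (the toy
nn.Linear model is illustrative only):

    import torch
    from torch import nn

    model = nn.Linear(2, 2)
    sd = model.state_dict()

    # The dict entry is a different tensor object, but it shares
    # storage with the live parameter ...
    assert sd['weight'] is not model.weight
    assert sd['weight'].data_ptr() == model.weight.data_ptr()

    # ... so an in-place write through the state dict mutates the model,
    sd['weight'].zero_()
    assert torch.all(model.weight == 0)

    # while rebinding a dict key only changes the dict. Swapping the
    # checkpoint's 'state_dict'/'ema_state_dict' values therefore changes
    # what gets saved without touching the running model.
    sd['weight'] = torch.ones(2, 2)
    assert torch.all(model.weight == 0)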
 mmengine/hooks/ema_hook.py              | 24 ++++++++++++++++++------
 tests/test_model/test_averaged_model.py |  4 ++--
 2 files changed, 20 insertions(+), 8 deletions(-)
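
A note on the key handling in _swap_ema_state_dict below: the averaged model
wraps the source model under an attribute (named "module" in mmengine's
BaseAveragedModel), so every parameter key in ema_state_dict carries a
"module." prefix that maps onto the bare key in state_dict, while unprefixed
bookkeeping keys (e.g. the update counter) are intentionally left alone.
A toy round-trip check under those assumptions (key names are hypothetical):

    checkpoint = {
        'state_dict': {'conv.weight': 'src'},
        'ema_state_dict': {'module.conv.weight': 'ema',
                           'steps': 'counter'},  # no prefix: not swapped
    }

    for k in checkpoint['ema_state_dict']:
        if k.startswith('module.'):
            src_key = k[len('module.'):]
            tmp = checkpoint['ema_state_dict'][k]
            checkpoint['ema_state_dict'][k] = checkpoint['state_dict'][src_key]
            checkpoint['state_dict'][src_key] = tmp

    assert checkpoint['state_dict']['conv.weight'] == 'ema'
    assert checkpoint['ema_state_dict']['module.conv.weight'] == 'src'
    assert checkpoint['ema_state_dict']['steps'] == 'counter'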

diff --git a/mmengine/hooks/ema_hook.py b/mmengine/hooks/ema_hook.py
index 8daa5855..2686a89e 100644
--- a/mmengine/hooks/ema_hook.py
+++ b/mmengine/hooks/ema_hook.py
@@ -70,18 +70,20 @@ class EMAHook(Hook):
 
     def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
         """Save ema parameters to checkpoint."""
-        # save ema parameters to the source model's state dict so that we can
-        # directly load the averaged model weights for deployment.
-        self._swap_ema_parameters()
         checkpoint['ema_state_dict'] = self.ema_model.state_dict()
-        self._swap_ema_parameters()
+        # Save ema parameters to the source model's state dict so that we
+        # can directly load the averaged model weights for deployment.
+        # Swap the state_dict key-values instead of the model parameters,
+        # because the state_dict is a shallow copy of the model
+        # parameters.
+        self._swap_ema_state_dict(checkpoint)
 
     def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
         """Resume ema parameters from checkpoint."""
-        self.ema_model.load_state_dict(checkpoint['ema_state_dict'])
         # The original model parameters are actually saved in ema field.
         # swap the weights back to resume ema state.
-        self._swap_ema_parameters()
+        self._swap_ema_state_dict(checkpoint)
+        self.ema_model.load_state_dict(checkpoint['ema_state_dict'])
 
     def _swap_ema_parameters(self) -> None:
         """Swap the parameter of model with ema_model."""
@@ -98,3 +100,13 @@ class EMAHook(Hook):
             tmp = p_avg.data.clone()
             p_avg.data.copy_(p_src.data)
             p_src.data.copy_(tmp)
+
+    def _swap_ema_state_dict(self, checkpoint):
+        """Swap the state dict values of model with ema_model."""
+        model_state = checkpoint['state_dict']
+        ema_state = checkpoint['ema_state_dict']
+        for k in ema_state:
+            if k.startswith('module.'):
+                tmp = ema_state[k]
+                ema_state[k] = model_state[k[7:]]
+                model_state[k[7:]] = tmp
diff --git a/tests/test_model/test_averaged_model.py b/tests/test_model/test_averaged_model.py
index f4ed1186..9afef4e6 100644
--- a/tests/test_model/test_averaged_model.py
+++ b/tests/test_model/test_averaged_model.py
@@ -187,7 +187,7 @@ class TestAveragedModel(TestCase):
                 if param.size() != torch.Size([])
             ]
             for p, p_avg in zip(params, averaged_params):
-                p.detach().add_(torch.randn_like(p))
+                p.add(torch.randn_like(p))
                 if i == 0:
                     updated_averaged_params.append(p.clone())
                 else:
@@ -234,7 +234,7 @@ class TestAveragedModel(TestCase):
                 if param.size() != torch.Size([])
             ]
             for p, p_avg in zip(params, averaged_params):
-                p.detach().add_(torch.randn_like(p))
+                p.add(torch.randn_like(p))
                 if i == 0:
                     updated_averaged_params.append(p.clone())
                 elif i % interval == 0:
-- 
GitLab