From a6f5297727f0ce9a6967f00499849a05c695665e Mon Sep 17 00:00:00 2001
From: takuoko <to78314910@gmail.com>
Date: Fri, 9 Sep 2022 12:41:12 +0900
Subject: [PATCH] [fix] EMAHook load state dict (#507)

* fix ema load_state_dict

* fix ema load_state_dict

* fix for test

* fix by review

* fix resume and keys
---
 mmengine/hooks/ema_hook.py        | 15 +++++++++------
 tests/test_hooks/test_ema_hook.py | 29 +++++++++++++++++++++++++++++
 2 files changed, 38 insertions(+), 6 deletions(-)

diff --git a/mmengine/hooks/ema_hook.py b/mmengine/hooks/ema_hook.py
index d0d5e3ac..f2712e6f 100644
--- a/mmengine/hooks/ema_hook.py
+++ b/mmengine/hooks/ema_hook.py
@@ -7,6 +7,7 @@ from typing import Dict, Optional
 from mmengine.logging import print_log
 from mmengine.model import is_model_wrapper
 from mmengine.registry import HOOKS, MODELS
+from mmengine.runner.checkpoint import _load_checkpoint_to_model
 from .hook import DATA_BATCH, Hook
 
 
@@ -171,7 +172,7 @@ class EMAHook(Hook):
         Args:
             runner (Runner): The runner of the testing process.
         """
-        if 'ema_state_dict' in checkpoint:
+        if 'ema_state_dict' in checkpoint and runner._resume:
             # The original model parameters are actually saved in ema
             # field swap the weights back to resume ema state.
             self._swap_ema_state_dict(checkpoint)
@@ -180,11 +181,13 @@ class EMAHook(Hook):
 
         # Support load checkpoint without ema state dict.
         else:
-            print_log(
-                'There is no `ema_state_dict` in checkpoint. '
-                '`EMAHook` will make a copy of `state_dict` as the '
-                'initial `ema_state_dict`', 'current', logging.WARNING)
-            self.ema_model.module.load_state_dict(
+            if runner._resume:
+                print_log(
+                    'There is no `ema_state_dict` in checkpoint. '
+                    '`EMAHook` will make a copy of `state_dict` as the '
+                    'initial `ema_state_dict`', 'current', logging.WARNING)
+            _load_checkpoint_to_model(
+                self.ema_model.module,
                 copy.deepcopy(checkpoint['state_dict']),
                 strict=self.strict_load)
 
diff --git a/tests/test_hooks/test_ema_hook.py b/tests/test_hooks/test_ema_hook.py
index 4b7e7d7b..3952033c 100644
--- a/tests/test_hooks/test_ema_hook.py
+++ b/tests/test_hooks/test_ema_hook.py
@@ -56,6 +56,16 @@ class ToyModel2(BaseModel, ToyModel):
         return super(BaseModel, self).forward(*args, **kwargs)
 
 
+class ToyModel3(BaseModel, ToyModel):
+
+    def __init__(self):
+        super().__init__()
+        self.linear1 = nn.Linear(2, 2)
+
+    def forward(self, *args, **kwargs):
+        return super(BaseModel, self).forward(*args, **kwargs)
+
+
 @DATASETS.register_module()
 class DummyDataset(Dataset):
     METAINFO = dict()  # type: ignore
@@ -203,6 +213,25 @@ class TestEMAHook(TestCase):
             experiment_name='test5')
         runner.test()
 
+        # Test does not load ckpt strict_loadly.
+        # Test load checkpoint without ema_state_dict
+        # Test with different size head.
+        runner = Runner(
+            model=ToyModel3(),
+            test_dataloader=dict(
+                dataset=dict(type='DummyDataset'),
+                sampler=dict(type='DefaultSampler', shuffle=True),
+                batch_size=3,
+                num_workers=0),
+            test_evaluator=evaluator,
+            test_cfg=dict(),
+            work_dir=self.temp_dir.name,
+            load_from=osp.join(self.temp_dir.name, 'epoch_2.pth'),
+            default_hooks=dict(logger=None),
+            custom_hooks=[dict(type='EMAHook', strict_load=False)],
+            experiment_name='test5')
+        runner.test()
+
         # Test enable ema at 5 epochs.
         runner = Runner(
             model=model,
-- 
GitLab