From a6f5297727f0ce9a6967f00499849a05c695665e Mon Sep 17 00:00:00 2001 From: takuoko <to78314910@gmail.com> Date: Fri, 9 Sep 2022 12:41:12 +0900 Subject: [PATCH] [fix] EMAHook load state dict (#507) * fix ema load_state_dict * fix ema load_state_dict * fix for test * fix by review * fix resume and keys --- mmengine/hooks/ema_hook.py | 15 +++++++++------ tests/test_hooks/test_ema_hook.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/mmengine/hooks/ema_hook.py b/mmengine/hooks/ema_hook.py index d0d5e3ac..f2712e6f 100644 --- a/mmengine/hooks/ema_hook.py +++ b/mmengine/hooks/ema_hook.py @@ -7,6 +7,7 @@ from typing import Dict, Optional from mmengine.logging import print_log from mmengine.model import is_model_wrapper from mmengine.registry import HOOKS, MODELS +from mmengine.runner.checkpoint import _load_checkpoint_to_model from .hook import DATA_BATCH, Hook @@ -171,7 +172,7 @@ class EMAHook(Hook): Args: runner (Runner): The runner of the testing process. """ - if 'ema_state_dict' in checkpoint: + if 'ema_state_dict' in checkpoint and runner._resume: # The original model parameters are actually saved in ema # field swap the weights back to resume ema state. self._swap_ema_state_dict(checkpoint) @@ -180,11 +181,13 @@ class EMAHook(Hook): # Support load checkpoint without ema state dict. else: - print_log( - 'There is no `ema_state_dict` in checkpoint. ' - '`EMAHook` will make a copy of `state_dict` as the ' - 'initial `ema_state_dict`', 'current', logging.WARNING) - self.ema_model.module.load_state_dict( + if runner._resume: + print_log( + 'There is no `ema_state_dict` in checkpoint. 
' + '`EMAHook` will make a copy of `state_dict` as the ' + 'initial `ema_state_dict`', 'current', logging.WARNING) + _load_checkpoint_to_model( + self.ema_model.module, copy.deepcopy(checkpoint['state_dict']), strict=self.strict_load) diff --git a/tests/test_hooks/test_ema_hook.py b/tests/test_hooks/test_ema_hook.py index 4b7e7d7b..3952033c 100644 --- a/tests/test_hooks/test_ema_hook.py +++ b/tests/test_hooks/test_ema_hook.py @@ -56,6 +56,16 @@ class ToyModel2(BaseModel, ToyModel): return super(BaseModel, self).forward(*args, **kwargs) +class ToyModel3(BaseModel, ToyModel): + + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(2, 2) + + def forward(self, *args, **kwargs): + return super(BaseModel, self).forward(*args, **kwargs) + + @DATASETS.register_module() class DummyDataset(Dataset): METAINFO = dict() # type: ignore @@ -203,6 +213,25 @@ class TestEMAHook(TestCase): experiment_name='test5') runner.test() + # Test loading the ckpt non-strictly (strict_load=False). + # Test loading a checkpoint without ema_state_dict. + # Test with a head of a different size. + runner = Runner( + model=ToyModel3(), + test_dataloader=dict( + dataset=dict(type='DummyDataset'), + sampler=dict(type='DefaultSampler', shuffle=True), + batch_size=3, + num_workers=0), + test_evaluator=evaluator, + test_cfg=dict(), + work_dir=self.temp_dir.name, + load_from=osp.join(self.temp_dir.name, 'epoch_2.pth'), + default_hooks=dict(logger=None), + custom_hooks=[dict(type='EMAHook', strict_load=False)], + experiment_name='test5') + runner.test() + # Test enable ema at 5 epochs. runner = Runner( model=model, -- GitLab