diff --git a/mmengine/hooks/ema_hook.py b/mmengine/hooks/ema_hook.py
index 8daa58552a4f76eb11343e8d2d4e16d7efbf834b..2686a89ebd47748da9564fcd9869a298243581c8 100644
--- a/mmengine/hooks/ema_hook.py
+++ b/mmengine/hooks/ema_hook.py
@@ -70,18 +70,20 @@ class EMAHook(Hook):
 
     def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
         """Save ema parameters to checkpoint."""
-        # save ema parameters to the source model's state dict so that we can
-        # directly load the averaged model weights for deployment.
-        self._swap_ema_parameters()
         checkpoint['ema_state_dict'] = self.ema_model.state_dict()
-        self._swap_ema_parameters()
+        # Save ema parameters to the source model's state dict so that we can
+        # directly load the averaged model weights for deployment.
+        # Swapping the state_dict key-values instead of swapping model
+        # parameters because the state_dict is a shallow copy of model
+        # parameters.
+        self._swap_ema_state_dict(checkpoint)
 
     def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
         """Resume ema parameters from checkpoint."""
-        self.ema_model.load_state_dict(checkpoint['ema_state_dict'])
         # The original model parameters are actually saved in ema field.
         # swap the weights back to resume ema state.
-        self._swap_ema_parameters()
+        self._swap_ema_state_dict(checkpoint)
+        self.ema_model.load_state_dict(checkpoint['ema_state_dict'])
 
     def _swap_ema_parameters(self) -> None:
         """Swap the parameter of model with ema_model."""
@@ -98,3 +100,13 @@ class EMAHook(Hook):
             tmp = p_avg.data.clone()
             p_avg.data.copy_(p_src.data)
             p_src.data.copy_(tmp)
+
+    def _swap_ema_state_dict(self, checkpoint):
+        """Swap the state dict values of model with ema_model."""
+        model_state = checkpoint['state_dict']
+        ema_state = checkpoint['ema_state_dict']
+        for k in ema_state:
+            if k[:7] == 'module.':
+                tmp = ema_state[k]
+                ema_state[k] = model_state[k[7:]]
+                model_state[k[7:]] = tmp
diff --git a/tests/test_model/test_averaged_model.py b/tests/test_model/test_averaged_model.py
index f4ed11861b6821af7c1a2e4d1389f95e67ac1d9b..9afef4e607d5e4e2fc8950878eaa9a5e7a687ad4 100644
--- a/tests/test_model/test_averaged_model.py
+++ b/tests/test_model/test_averaged_model.py
@@ -187,7 +187,7 @@ class TestAveragedModel(TestCase):
                 if param.size() != torch.Size([])
             ]
             for p, p_avg in zip(params, averaged_params):
-                p.detach().add_(torch.randn_like(p))
+                p.add(torch.randn_like(p))
                 if i == 0:
                     updated_averaged_params.append(p.clone())
                 else:
@@ -234,7 +234,7 @@ class TestAveragedModel(TestCase):
                 if param.size() != torch.Size([])
             ]
             for p, p_avg in zip(params, averaged_params):
-                p.detach().add_(torch.randn_like(p))
+                p.add(torch.randn_like(p))
                 if i == 0:
                     updated_averaged_params.append(p.clone())
                 elif i % interval == 0:
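Note (not part of the patch): the new comment in `before_save_checkpoint` rests on the fact that `nn.Module.state_dict()` is a shallow copy, so the tensors it returns share storage with the live parameters; swapping parameters after calling `state_dict()` would therefore also corrupt the captured `ema_state_dict`. Below is a minimal sketch of that behaviour and of the key/value swap the patch performs instead. The throwaway `nn.Linear` model and the hand-built `ckpt` dict are hypothetical, for illustration only.

```python
import torch
import torch.nn as nn

# state_dict() is a shallow copy: its tensors share storage with the module's
# parameters, so an in-place update of a parameter after grabbing the state
# dict also changes the "saved" tensors. This is why the hook cannot swap live
# parameters after calling state_dict() and swaps the checkpoint entries.
model = nn.Linear(2, 2)
state = model.state_dict()
with torch.no_grad():
    model.weight.add_(1.0)
assert torch.equal(state['weight'], model.weight)  # aliasing, not a copy

# Swapping dictionary values, as _swap_ema_state_dict does, only rebinds keys
# inside the checkpoint and leaves the live parameters untouched. The
# 'module.' prefix mirrors the ema model wrapping the source model.
ckpt = {
    'state_dict': {'weight': torch.zeros(2, 2)},
    'ema_state_dict': {'module.weight': torch.ones(2, 2)},
}
for k in list(ckpt['ema_state_dict']):
    if k.startswith('module.'):
        src_key = k[len('module.'):]
        ckpt['ema_state_dict'][k], ckpt['state_dict'][src_key] = (
            ckpt['state_dict'][src_key], ckpt['ema_state_dict'][k])
assert torch.equal(ckpt['state_dict']['weight'], torch.ones(2, 2))
```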