From 172b9ded4ade88bec81ee9efdc9ad1587abf482f Mon Sep 17 00:00:00 2001 From: RangiLyu <lyuchqi@gmail.com> Date: Mon, 30 May 2022 16:51:06 +0800 Subject: [PATCH] [Fix] Fix ema state dict swapping in EMAHook and torch1.5 ut. (#266) * [Fix] Fix ema state dict swapping in EMAHook. * fix pt1.5 ut * add more comments --- mmengine/hooks/ema_hook.py | 24 ++++++++++++++++++------ tests/test_model/test_averaged_model.py | 4 ++-- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/mmengine/hooks/ema_hook.py b/mmengine/hooks/ema_hook.py index 8daa5855..2686a89e 100644 --- a/mmengine/hooks/ema_hook.py +++ b/mmengine/hooks/ema_hook.py @@ -70,18 +70,20 @@ class EMAHook(Hook): def before_save_checkpoint(self, runner, checkpoint: dict) -> None: """Save ema parameters to checkpoint.""" - # save ema parameters to the source model's state dict so that we can - # directly load the averaged model weights for deployment. - self._swap_ema_parameters() checkpoint['ema_state_dict'] = self.ema_model.state_dict() - self._swap_ema_parameters() + # Save ema parameters to the source model's state dict so that we can + # directly load the averaged model weights for deployment. + # Swapping the state_dict key-values instead of swapping model + # parameters because the state_dict is a shallow copy of model + # parameters. + self._swap_ema_state_dict(checkpoint) def after_load_checkpoint(self, runner, checkpoint: dict) -> None: """Resume ema parameters from checkpoint.""" - self.ema_model.load_state_dict(checkpoint['ema_state_dict']) # The original model parameters are actually saved in ema field. # swap the weights back to resume ema state. - self._swap_ema_parameters() + self._swap_ema_state_dict(checkpoint) + self.ema_model.load_state_dict(checkpoint['ema_state_dict']) def _swap_ema_parameters(self) -> None: """Swap the parameter of model with ema_model.""" @@ -98,3 +100,13 @@ class EMAHook(Hook): tmp = p_avg.data.clone() p_avg.data.copy_(p_src.data) p_src.data.copy_(tmp) + + def _swap_ema_state_dict(self, checkpoint): + """Swap the state dict values of model with ema_model.""" + model_state = checkpoint['state_dict'] + ema_state = checkpoint['ema_state_dict'] + for k in ema_state: + if k[:7] == 'module.': + tmp = ema_state[k] + ema_state[k] = model_state[k[7:]] + model_state[k[7:]] = tmp diff --git a/tests/test_model/test_averaged_model.py b/tests/test_model/test_averaged_model.py index f4ed1186..9afef4e6 100644 --- a/tests/test_model/test_averaged_model.py +++ b/tests/test_model/test_averaged_model.py @@ -187,7 +187,7 @@ class TestAveragedModel(TestCase): if param.size() != torch.Size([]) ] for p, p_avg in zip(params, averaged_params): - p.detach().add_(torch.randn_like(p)) + p.add(torch.randn_like(p)) if i == 0: updated_averaged_params.append(p.clone()) else: @@ -234,7 +234,7 @@ class TestAveragedModel(TestCase): if param.size() != torch.Size([]) ] for p, p_avg in zip(params, averaged_params): - p.detach().add_(torch.randn_like(p)) + p.add(torch.randn_like(p)) if i == 0: updated_averaged_params.append(p.clone()) elif i % interval == 0: -- GitLab