Unverified commit 172b9ded, authored by RangiLyu and committed by GitHub

[Fix] Fix ema state dict swapping in EMAHook and torch1.5 ut. (#266)

* [Fix] Fix ema state dict swapping in EMAHook.

* fix pt1.5 ut

* add more comments
parent 40daf46a
@@ -70,18 +70,20 @@ class EMAHook(Hook):
 
     def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
         """Save ema parameters to checkpoint."""
-        # save ema parameters to the source model's state dict so that we can
-        # directly load the averaged model weights for deployment.
-        self._swap_ema_parameters()
         checkpoint['ema_state_dict'] = self.ema_model.state_dict()
-        self._swap_ema_parameters()
+        # Save ema parameters to the source model's state dict so that we can
+        # directly load the averaged model weights for deployment.
+        # Swapping the state_dict key-values instead of swapping model
+        # parameters because the state_dict is a shallow copy of model
+        # parameters.
+        self._swap_ema_state_dict(checkpoint)
 
     def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
         """Resume ema parameters from checkpoint."""
-        self.ema_model.load_state_dict(checkpoint['ema_state_dict'])
         # The original model parameters are actually saved in ema field.
         # swap the weights back to resume ema state.
-        self._swap_ema_parameters()
+        self._swap_ema_state_dict(checkpoint)
+        self.ema_model.load_state_dict(checkpoint['ema_state_dict'])
 
     def _swap_ema_parameters(self) -> None:
         """Swap the parameter of model with ema_model."""
@@ -98,3 +100,13 @@ class EMAHook(Hook):
             tmp = p_avg.data.clone()
             p_avg.data.copy_(p_src.data)
             p_src.data.copy_(tmp)
+
+    def _swap_ema_state_dict(self, checkpoint):
+        """Swap the state dict values of model with ema_model."""
+        model_state = checkpoint['state_dict']
+        ema_state = checkpoint['ema_state_dict']
+        for k in ema_state:
+            if k[:7] == 'module.':
+                tmp = ema_state[k]
+                ema_state[k] = model_state[k[7:]]
+                model_state[k[7:]] = tmp
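Why the swap now happens on the checkpoint dicts rather than on the live parameters: state_dict() is a shallow copy, so the tensors stored under checkpoint['ema_state_dict'] share storage with the EMA model's parameters, and an in-place parameter swap performed after that assignment would silently rewrite the "saved" values as well. A minimal sketch of that pitfall, not part of the commit and using a plain nn.Linear as a stand-in for the source model:

    import torch
    import torch.nn as nn

    model = nn.Linear(2, 2)
    captured = model.state_dict()  # shallow copy: values share storage with the parameters
    with torch.no_grad():
        model.weight.copy_(torch.zeros(2, 2))  # in-place edit after the "snapshot"
    print(torch.equal(captured['weight'], torch.zeros(2, 2)))  # True: the captured dict changed too

The new _swap_ema_state_dict avoids this by exchanging plain dict entries between checkpoint['state_dict'] and checkpoint['ema_state_dict'], which only rebinds keys and never touches parameter storage. The 'module.' prefix handling implies that the EMA wrapper keeps the source model as its module submodule, so its state dict keys are the source keys with a 'module.' prefix stripped before indexing the source state dict.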
@@ -187,7 +187,7 @@ class TestAveragedModel(TestCase):
                 if param.size() != torch.Size([])
             ]
             for p, p_avg in zip(params, averaged_params):
-                p.detach().add_(torch.randn_like(p))
+                p.add(torch.randn_like(p))
                 if i == 0:
                     updated_averaged_params.append(p.clone())
                 else:
@@ -234,7 +234,7 @@ class TestAveragedModel(TestCase):
                 if param.size() != torch.Size([])
             ]
             for p, p_avg in zip(params, averaged_params):
-                p.detach().add_(torch.randn_like(p))
+                p.add(torch.randn_like(p))
                 if i == 0:
                     updated_averaged_params.append(p.clone())
                 elif i % interval == 0:
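On the test-side change: Tensor.add_ is the in-place variant and mutates the tensor it is called on, while Tensor.add is out-of-place and returns a fresh tensor, leaving the parameter untouched. A quick illustration of the two call styles touched by the torch 1.5 fix, not taken from the commit:

    import torch

    p = torch.nn.Parameter(torch.zeros(3))
    p.detach().add_(torch.ones(3))  # in-place through a detached view: p itself becomes all ones
    q = p.add(torch.ones(3))        # out-of-place: q holds p + 1, p stays all ones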