Unverified commit 8d25dbde, authored by RangiLyu, committed by GitHub

[Fix] Fix EMAHook trigger train loop and AveragedModel sync buffer. (#467)

* [Fix] Fix EMAHook trigger train loop init during testing.

* fix sync buffer

* update ut

* fix sync buffer

* fix sync buffer
parent 18a0338c
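For context on the pieces touched below: EMAHook maintains an exponentially averaged copy of the training model through the ExponentialMovingAverage wrapper (a BaseAveragedModel subclass) and is enabled via custom_hooks=[dict(type='EMAHook', ...)], as the updated tests further down show. A minimal, hook-free sketch of that wrapper; the layer size and the in-place parameter change are arbitrary placeholders, not part of this commit:

import torch
import torch.nn as nn
from mmengine.model import ExponentialMovingAverage

# Toy source model; sizes are arbitrary.
model = nn.Linear(4, 2)
ema_model = ExponentialMovingAverage(model)

# Pretend one optimizer step happened on the source model.
with torch.no_grad():
    for p in model.parameters():
        p.add_(0.1)

# EMAHook calls this after each training iteration once it is enabled.
ema_model.update_parameters(model)

# The averaged weights live on the wrapped copy.
print(ema_model.module.weight)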
@@ -71,6 +71,12 @@ class EMAHook(Hook):
         self.ema_model = MODELS.build(
             self.ema_cfg, default_args=dict(model=self.src_model))
 
+    def before_train(self, runner) -> None:
+        """Check the begin_epoch/iter is smaller than max_epochs/iters.
+
+        Args:
+            runner (Runner): The runner of the training process.
+        """
         if self.enabled_by_epoch:
             assert self.begin_epoch <= runner.max_epochs, (
                 'self.begin_epoch should be smaller than runner.max_epochs: '
@@ -96,6 +102,11 @@ class EMAHook(Hook):
         """
         if self._ema_started(runner):
             self.ema_model.update_parameters(self.src_model)
+        else:
+            ema_params = self.ema_model.module.state_dict()
+            src_params = self.src_model.state_dict()
+            for k, p in ema_params.items():
+                p.data.copy_(src_params[k].data)
 
     def before_val_epoch(self, runner) -> None:
         """We load parameter values from ema model to source model before
@@ -104,8 +115,7 @@ class EMAHook(Hook):
         Args:
             runner (Runner): The runner of the training process.
         """
-        if self._ema_started(runner):
-            self._swap_ema_parameters()
+        self._swap_ema_parameters()
 
     def after_val_epoch(self,
                         runner,
@@ -118,8 +128,7 @@ class EMAHook(Hook):
             metrics on validation dataset. The keys are the names of the
             metrics, and the values are corresponding results.
         """
-        if self._ema_started(runner):
-            self._swap_ema_parameters()
+        self._swap_ema_parameters()
 
     def before_test_epoch(self, runner) -> None:
         """We load parameter values from ema model to source model before test.
@@ -127,8 +136,7 @@ class EMAHook(Hook):
         Args:
             runner (Runner): The runner of the training process.
         """
-        if self._ema_started(runner):
-            self._swap_ema_parameters()
+        self._swap_ema_parameters()
 
     def after_test_epoch(self,
                          runner,
@@ -141,8 +149,7 @@ class EMAHook(Hook):
             metrics on test dataset. The keys are the names of the
             metrics, and the values are corresponding results.
         """
-        if self._ema_started(runner):
-            self._swap_ema_parameters()
+        self._swap_ema_parameters()
 
     def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
         """Save ema parameters to checkpoint.
@@ -150,14 +157,13 @@ class EMAHook(Hook):
         Args:
             runner (Runner): The runner of the testing process.
         """
-        if self._ema_started(runner):
-            checkpoint['ema_state_dict'] = self.ema_model.state_dict()
-            # Save ema parameters to the source model's state dict so that we
-            # can directly load the averaged model weights for deployment.
-            # Swapping the state_dict key-values instead of swapping model
-            # parameters because the state_dict is a shallow copy of model
-            # parameters.
-            self._swap_ema_state_dict(checkpoint)
+        checkpoint['ema_state_dict'] = self.ema_model.state_dict()
+        # Save ema parameters to the source model's state dict so that we
+        # can directly load the averaged model weights for deployment.
+        # Swapping the state_dict key-values instead of swapping model
+        # parameters because the state_dict is a shallow copy of model
+        # parameters.
+        self._swap_ema_state_dict(checkpoint)
 
     def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
         """Resume ema parameters from checkpoint.
@@ -165,23 +171,22 @@ class EMAHook(Hook):
         Args:
            runner (Runner): The runner of the testing process.
         """
-        if self._ema_started(runner):
-            if 'ema_state_dict' in checkpoint:
-                # The original model parameters are actually saved in ema
-                # field swap the weights back to resume ema state.
-                self._swap_ema_state_dict(checkpoint)
-                self.ema_model.load_state_dict(
-                    checkpoint['ema_state_dict'], strict=self.strict_load)
-
-            # Support load checkpoint without ema state dict.
-            else:
-                print_log(
-                    'There is no `ema_state_dict` in checkpoint. '
-                    '`EMAHook` will make a copy of `state_dict` as the '
-                    'initial `ema_state_dict`', 'current', logging.WARNING)
-                self.ema_model.module.load_state_dict(
-                    copy.deepcopy(checkpoint['state_dict']),
-                    strict=self.strict_load)
+        if 'ema_state_dict' in checkpoint:
+            # The original model parameters are actually saved in ema
+            # field swap the weights back to resume ema state.
+            self._swap_ema_state_dict(checkpoint)
+            self.ema_model.load_state_dict(
+                checkpoint['ema_state_dict'], strict=self.strict_load)
+
+        # Support load checkpoint without ema state dict.
+        else:
+            print_log(
+                'There is no `ema_state_dict` in checkpoint. '
+                '`EMAHook` will make a copy of `state_dict` as the '
+                'initial `ema_state_dict`', 'current', logging.WARNING)
+            self.ema_model.module.load_state_dict(
+                copy.deepcopy(checkpoint['state_dict']),
+                strict=self.strict_load)
 
     def _swap_ema_parameters(self) -> None:
         """Swap the parameter of model with ema_model."""
...
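To make the intent of the two checkpoint hooks above concrete: before saving, the averaged weights are moved into state_dict (so the checkpoint can be deployed directly) while the raw training weights are parked under ema_state_dict; loading reverses the swap. A rough, self-contained sketch of that key-value swap on a toy checkpoint, assuming the EMA state dict prefixes every model key with 'module.' as the updated tests expect. This is an illustration, not the hook's actual _swap_ema_state_dict implementation:

import torch

# Toy checkpoint: raw training weights and their EMA counterparts.
checkpoint = {
    'state_dict': {'fc.weight': torch.zeros(2, 2)},
    'ema_state_dict': {'module.fc.weight': torch.ones(2, 2)},
}

def swap_ema_state_dict(ckpt: dict) -> None:
    """Swap values between 'state_dict' and the 'module.'-prefixed EMA keys."""
    model_state = ckpt['state_dict']
    ema_state = ckpt['ema_state_dict']
    for key in model_state:
        model_state[key], ema_state['module.' + key] = (
            ema_state['module.' + key], model_state[key])

swap_ema_state_dict(checkpoint)
# 'state_dict' now holds the averaged weights; swapping again restores it.
print(checkpoint['state_dict']['fc.weight'])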
@@ -106,6 +106,11 @@ class BaseAveragedModel(nn.Module):
                     self.avg_func(p_avg.data,
                                   src_parameters[k].data.to(device),
                                   self.steps)
+        if not self.update_buffers:
+            # If not update the buffers,
+            # keep the buffers in sync with the source model.
+            for b_avg, b_src in zip(self.module.buffers(), model.buffers()):
+                b_avg.data.copy_(b_src.data.to(b_avg.device))
         self.steps += 1
...
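The added branch above mainly matters for buffers such as BatchNorm running statistics: with the default update_buffers=False they are not part of the moving average, and before this change they simply stayed at their initial values in the averaged copy. A small sketch of the behaviour the fix gives, using a toy BatchNorm layer (the layer and input sizes are arbitrary):

import torch
import torch.nn as nn
from mmengine.model import ExponentialMovingAverage

model = nn.BatchNorm1d(3)
ema_model = ExponentialMovingAverage(model)  # update_buffers defaults to False

# A forward pass in training mode updates the source model's running stats.
model.train()
model(torch.randn(8, 3))

# With the fix, update_parameters also copies the buffers from the source
# model, so the averaged copy sees the fresh running statistics.
ema_model.update_parameters(model)
print(torch.allclose(ema_model.module.running_mean, model.running_mean))  # True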
@@ -14,6 +14,7 @@ from mmengine.model import BaseModel, ExponentialMovingAverage
 from mmengine.optim import OptimWrapper
 from mmengine.registry import DATASETS, MODEL_WRAPPERS
 from mmengine.runner import Runner
+from mmengine.testing import assert_allclose
 
 
 class ToyModel(nn.Module):
@@ -225,9 +226,13 @@ class TestEMAHook(TestCase):
             custom_hooks=[dict(type='EMAHook', begin_epoch=5)],
             experiment_name='test6')
         runner.train()
-        state_dict = torch.load(osp.join(self.temp_dir.name, 'epoch_4.pth'))
-        self.assertNotIn('ema_state_dict', state_dict)
-        state_dict = torch.load(osp.join(self.temp_dir.name, 'epoch_5.pth'))
+        state_dict = torch.load(
+            osp.join(self.temp_dir.name, 'epoch_4.pth'), map_location='cpu')
+        self.assertIn('ema_state_dict', state_dict)
+        for k, v in state_dict['state_dict'].items():
+            assert_allclose(v, state_dict['ema_state_dict']['module.' + k])
+        state_dict = torch.load(
+            osp.join(self.temp_dir.name, 'epoch_5.pth'), map_location='cpu')
         self.assertIn('ema_state_dict', state_dict)
 
         # Test enable ema at 5 iterations.
@@ -255,7 +260,11 @@ class TestEMAHook(TestCase):
             custom_hooks=[dict(type='EMAHook', begin_iter=5)],
             experiment_name='test7')
         runner.train()
-        state_dict = torch.load(osp.join(self.temp_dir.name, 'iter_4.pth'))
-        self.assertNotIn('ema_state_dict', state_dict)
-        state_dict = torch.load(osp.join(self.temp_dir.name, 'iter_5.pth'))
+        state_dict = torch.load(
+            osp.join(self.temp_dir.name, 'iter_4.pth'), map_location='cpu')
+        self.assertIn('ema_state_dict', state_dict)
+        for k, v in state_dict['state_dict'].items():
+            assert_allclose(v, state_dict['ema_state_dict']['module.' + k])
+        state_dict = torch.load(
+            osp.join(self.temp_dir.name, 'iter_5.pth'), map_location='cpu')
         self.assertIn('ema_state_dict', state_dict)
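As a consequence of the swap in before_save_checkpoint earlier in this diff, the checkpoints these tests inspect can be loaded into a plain model with no EMA-aware code, since state_dict already holds the averaged weights. A minimal in-memory sketch; the toy module and the hand-built checkpoint are stand-ins, not the ToyModel or files used by the tests:

import torch
import torch.nn as nn

class ToyDeployModel(nn.Module):
    """Stand-in for the model that was trained with EMAHook."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

# Stand-in for a checkpoint saved by EMAHook after the swap:
# 'state_dict' holds the averaged weights, 'ema_state_dict' the raw ones.
checkpoint = {
    'state_dict': {
        'fc.weight': torch.full((2, 4), 0.5),
        'fc.bias': torch.zeros(2),
    },
}

model = ToyDeployModel()
# The averaged weights load directly; no EMAHook involved at inference time.
model.load_state_dict(checkpoint['state_dict'])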