diff --git a/mmengine/hooks/ema_hook.py b/mmengine/hooks/ema_hook.py
index d4834b7c90224842b8f79ae1286eb9b8b84a92ee..d0d5e3acb61b0bd69b72d167975eb499a9a70860 100644
--- a/mmengine/hooks/ema_hook.py
+++ b/mmengine/hooks/ema_hook.py
@@ -71,6 +71,12 @@ class EMAHook(Hook):
         self.ema_model = MODELS.build(
             self.ema_cfg, default_args=dict(model=self.src_model))
 
+    def before_train(self, runner) -> None:
+        """Check the begin_epoch/iter is smaller than max_epochs/iters.
+
+        Args:
+            runner (Runner): The runner of the training process.
+        """
         if self.enabled_by_epoch:
             assert self.begin_epoch <= runner.max_epochs, (
                 'self.begin_epoch should be smaller than runner.max_epochs: '
@@ -96,6 +102,11 @@ class EMAHook(Hook):
         """
         if self._ema_started(runner):
             self.ema_model.update_parameters(self.src_model)
+        else:
+            ema_params = self.ema_model.module.state_dict()
+            src_params = self.src_model.state_dict()
+            for k, p in ema_params.items():
+                p.data.copy_(src_params[k].data)
 
     def before_val_epoch(self, runner) -> None:
         """We load parameter values from ema model to source model before
@@ -104,8 +115,7 @@ class EMAHook(Hook):
         Args:
             runner (Runner): The runner of the training process.
         """
-        if self._ema_started(runner):
-            self._swap_ema_parameters()
+        self._swap_ema_parameters()
 
     def after_val_epoch(self,
                         runner,
@@ -118,8 +128,7 @@ class EMAHook(Hook):
                 metrics on validation dataset. The keys are the names of the
                 metrics, and the values are corresponding results.
         """
-        if self._ema_started(runner):
-            self._swap_ema_parameters()
+        self._swap_ema_parameters()
 
     def before_test_epoch(self, runner) -> None:
         """We load parameter values from ema model to source model before test.
@@ -127,8 +136,7 @@ class EMAHook(Hook):
         Args:
             runner (Runner): The runner of the training process.
         """
-        if self._ema_started(runner):
-            self._swap_ema_parameters()
+        self._swap_ema_parameters()
 
     def after_test_epoch(self,
                          runner,
@@ -141,8 +149,7 @@ class EMAHook(Hook):
                 metrics on test dataset. The keys are the names of the
                 metrics, and the values are corresponding results.
         """
-        if self._ema_started(runner):
-            self._swap_ema_parameters()
+        self._swap_ema_parameters()
 
     def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
         """Save ema parameters to checkpoint.
@@ -150,14 +157,13 @@ class EMAHook(Hook):
         Args:
             runner (Runner): The runner of the testing process.
        """
-        if self._ema_started(runner):
-            checkpoint['ema_state_dict'] = self.ema_model.state_dict()
-            # Save ema parameters to the source model's state dict so that we
-            # can directly load the averaged model weights for deployment.
-            # Swapping the state_dict key-values instead of swapping model
-            # parameters because the state_dict is a shallow copy of model
-            # parameters.
-            self._swap_ema_state_dict(checkpoint)
+        checkpoint['ema_state_dict'] = self.ema_model.state_dict()
+        # Save ema parameters to the source model's state dict so that we
+        # can directly load the averaged model weights for deployment.
+        # Swapping the state_dict key-values instead of swapping model
+        # parameters because the state_dict is a shallow copy of model
+        # parameters.
+        self._swap_ema_state_dict(checkpoint)
 
     def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
         """Resume ema parameters from checkpoint.
@@ -165,23 +171,22 @@ class EMAHook(Hook):
         Args:
             runner (Runner): The runner of the testing process.
         """
-        if self._ema_started(runner):
-            if 'ema_state_dict' in checkpoint:
-                # The original model parameters are actually saved in ema
-                # field swap the weights back to resume ema state.
-                self._swap_ema_state_dict(checkpoint)
-                self.ema_model.load_state_dict(
-                    checkpoint['ema_state_dict'], strict=self.strict_load)
-
-            # Support load checkpoint without ema state dict.
-            else:
-                print_log(
-                    'There is no `ema_state_dict` in checkpoint. '
-                    '`EMAHook` will make a copy of `state_dict` as the '
-                    'initial `ema_state_dict`', 'current', logging.WARNING)
-                self.ema_model.module.load_state_dict(
-                    copy.deepcopy(checkpoint['state_dict']),
-                    strict=self.strict_load)
+        if 'ema_state_dict' in checkpoint:
+            # The original model parameters are actually saved in ema
+            # field swap the weights back to resume ema state.
+            self._swap_ema_state_dict(checkpoint)
+            self.ema_model.load_state_dict(
+                checkpoint['ema_state_dict'], strict=self.strict_load)
+
+        # Support load checkpoint without ema state dict.
+        else:
+            print_log(
+                'There is no `ema_state_dict` in checkpoint. '
+                '`EMAHook` will make a copy of `state_dict` as the '
+                'initial `ema_state_dict`', 'current', logging.WARNING)
+            self.ema_model.module.load_state_dict(
+                copy.deepcopy(checkpoint['state_dict']),
+                strict=self.strict_load)
 
     def _swap_ema_parameters(self) -> None:
         """Swap the parameter of model with ema_model."""
diff --git a/mmengine/model/averaged_model.py b/mmengine/model/averaged_model.py
index 8a47e8e07157c5a6f53f24c4f1507504f4e0aaa9..be1abc30bda2352f5aceee00f3ccc11e1cd5d2c2 100644
--- a/mmengine/model/averaged_model.py
+++ b/mmengine/model/averaged_model.py
@@ -106,6 +111,11 @@ class BaseAveragedModel(nn.Module):
                     self.avg_func(p_avg.data,
                                   src_parameters[k].data.to(device),
                                   self.steps)
+        if not self.update_buffers:
+            # If not update the buffers,
+            # keep the buffers in sync with the source model.
+            for b_avg, b_src in zip(self.module.buffers(), model.buffers()):
+                b_avg.data.copy_(b_src.data.to(b_avg.device))
         self.steps += 1
 
 
diff --git a/tests/test_hooks/test_ema_hook.py b/tests/test_hooks/test_ema_hook.py
index 4fcced7dbca94ecd6d74c165e3fb2e27376929de..4b7e7d7bca2928e1f9dcc04518e2d2ed003497d7 100644
--- a/tests/test_hooks/test_ema_hook.py
+++ b/tests/test_hooks/test_ema_hook.py
@@ -14,6 +14,7 @@ from mmengine.model import BaseModel, ExponentialMovingAverage
 from mmengine.optim import OptimWrapper
 from mmengine.registry import DATASETS, MODEL_WRAPPERS
 from mmengine.runner import Runner
+from mmengine.testing import assert_allclose
 
 
 class ToyModel(nn.Module):
@@ -225,9 +226,13 @@ class TestEMAHook(TestCase):
             custom_hooks=[dict(type='EMAHook', begin_epoch=5)],
             experiment_name='test6')
         runner.train()
-        state_dict = torch.load(osp.join(self.temp_dir.name, 'epoch_4.pth'))
-        self.assertNotIn('ema_state_dict', state_dict)
-        state_dict = torch.load(osp.join(self.temp_dir.name, 'epoch_5.pth'))
+        state_dict = torch.load(
+            osp.join(self.temp_dir.name, 'epoch_4.pth'), map_location='cpu')
+        self.assertIn('ema_state_dict', state_dict)
+        for k, v in state_dict['state_dict'].items():
+            assert_allclose(v, state_dict['ema_state_dict']['module.' + k])
+        state_dict = torch.load(
+            osp.join(self.temp_dir.name, 'epoch_5.pth'), map_location='cpu')
         self.assertIn('ema_state_dict', state_dict)
 
         # Test enable ema at 5 iterations.
@@ -255,7 +260,11 @@ class TestEMAHook(TestCase):
             custom_hooks=[dict(type='EMAHook', begin_iter=5)],
             experiment_name='test7')
         runner.train()
-        state_dict = torch.load(osp.join(self.temp_dir.name, 'iter_4.pth'))
-        self.assertNotIn('ema_state_dict', state_dict)
-        state_dict = torch.load(osp.join(self.temp_dir.name, 'iter_5.pth'))
+        state_dict = torch.load(
+            osp.join(self.temp_dir.name, 'iter_4.pth'), map_location='cpu')
+        self.assertIn('ema_state_dict', state_dict)
+        for k, v in state_dict['state_dict'].items():
+            assert_allclose(v, state_dict['ema_state_dict']['module.' + k])
+        state_dict = torch.load(
+            osp.join(self.temp_dir.name, 'iter_5.pth'), map_location='cpu')
         self.assertIn('ema_state_dict', state_dict)