From 99de0951afc2fc99efed115c8c204f48a092de3e Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Mon, 8 Aug 2022 20:26:16 +0800
Subject: [PATCH] [Enhance] EMAHook support does not load checkpoint strictly
 (#352)

* BaseAveragedModel support load ckpt without module prefix

* refine docstring

* allow EMAHook does not load ckpt strictly

* add unit test for strict argument of EMAHook

* sync remote

* sync remote

* clean the code

* ema hook supports setting start iter

* fix unit test

* fix as comment

* fix as comment

* describe kwargs
---
 mmengine/hooks/ema_hook.py       | 122 +++++++++++++++++++++++--------
 tests/test_hook/test_ema_hook.py |  86 ++++++++++++++++++++++
 2 files changed, 178 insertions(+), 30 deletions(-)

diff --git a/mmengine/hooks/ema_hook.py b/mmengine/hooks/ema_hook.py
index 92afd96b..4394744e 100644
--- a/mmengine/hooks/ema_hook.py
+++ b/mmengine/hooks/ema_hook.py
@@ -19,17 +19,44 @@ class EMAHook(Hook):
         - EMAHook takes priority over CheckpointHook.
         - The original model parameters are actually saved in ema field after
           train.
+        - ``begin_iter`` and ``begin_epoch`` cannot be set at the same time.
 
     Args:
         ema_type (str): The type of EMA strategy to use. You can find the
-            supported strategies in ``mmengine.model.averaged_model``.
-            Defaults to 'ExponentialMovingAverage'
+            supported strategies in :mod:`mmengine.model.averaged_model`.
+            Defaults to 'ExponentialMovingAverage'.
+        strict_load (bool): Whether to strictly enforce that the keys of
+            ``state_dict`` in checkpoint match the keys returned by
+            ``self.module.state_dict``. Defaults to True.
+        begin_iter (int): The iteration at which ``EMAHook`` is enabled.
+            Defaults to 0.
+        begin_epoch (int): The epoch at which ``EMAHook`` is enabled.
+            Defaults to 0.
+        **kwargs: Keyword arguments passed to subclasses of
+            :obj:`BaseAveragedModel`.
     """
 
     priority = 'NORMAL'
 
-    def __init__(self, ema_type: str = 'ExponentialMovingAverage', **kwargs):
+    def __init__(self,
+                 ema_type: str = 'ExponentialMovingAverage',
+                 strict_load: bool = True,
+                 begin_iter: int = 0,
+                 begin_epoch: int = 0,
+                 **kwargs):
+        self.strict_load = strict_load
         self.ema_cfg = dict(type=ema_type, **kwargs)
+        assert not (begin_iter != 0 and begin_epoch != 0), (
+            '`begin_iter` and `begin_epoch` should not be both set.')
+        assert begin_iter >= 0, (
+            f'`begin_iter` must not be smaller than 0, but got {begin_iter}')
+        assert begin_epoch >= 0, (
+            f'`begin_epoch` must not be smaller than 0, but got {begin_epoch}')
+        self.begin_iter = begin_iter
+        self.begin_epoch = begin_epoch
+        # If neither `begin_epoch` nor `begin_iter` is set, `EMAHook` is
+        # enabled from iteration 0.
+        self.enabled_by_epoch = self.begin_epoch > 0
 
     def before_run(self, runner) -> None:
         """Create an ema copy of the model."""
@@ -40,64 +67,81 @@ class EMAHook(Hook):
         self.ema_model = MODELS.build(
             self.ema_cfg, default_args=dict(model=self.src_model))
 
+        if self.enabled_by_epoch:
+            assert self.begin_epoch <= runner.max_epochs, (
+                'self.begin_epoch should not be larger than '
+                f'runner.max_epochs: {runner.max_epochs}, but got '
+                f'begin_epoch: {self.begin_epoch}')
+        else:
+            assert self.begin_iter <= runner.max_iters, (
+                'self.begin_iter should not be larger than '
+                f'runner.max_iters: {runner.max_iters}, but got '
+                f'begin_iter: {self.begin_iter}')
+
     def after_train_iter(self,
                          runner,
                          batch_idx: int,
                          data_batch: DATA_BATCH = None,
                          outputs: Optional[dict] = None) -> None:
         """Update ema parameter."""
-        self.ema_model.update_parameters(self.src_model)
+        if self._ema_started(runner):
+            self.ema_model.update_parameters(self.src_model)
 
     def before_val_epoch(self, runner) -> None:
         """We load parameter values from ema model to source model before
         validation."""
-        self._swap_ema_parameters()
+        if self._ema_started(runner):
+            self._swap_ema_parameters()
 
     def after_val_epoch(self,
                         runner,
                         metrics: Optional[Dict[str, float]] = None) -> None:
         """We recover source model's parameter from ema model after
         validation."""
-        self._swap_ema_parameters()
+        if self._ema_started(runner):
+            self._swap_ema_parameters()
 
     def before_test_epoch(self, runner) -> None:
         """We load parameter values from ema model to source model before
         test."""
-        self._swap_ema_parameters()
+        if self._ema_started(runner):
+            self._swap_ema_parameters()
 
     def after_test_epoch(self,
                          runner,
                          metrics: Optional[Dict[str, float]] = None) -> None:
         """We recover source model's parameter from ema model after test."""
-        self._swap_ema_parameters()
+        if self._ema_started(runner):
+            self._swap_ema_parameters()
 
     def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
         """Save ema parameters to checkpoint."""
-        checkpoint['ema_state_dict'] = self.ema_model.state_dict()
-        # Save ema parameters to the source model's state dict so that we can
-        # directly load the averaged model weights for deployment.
-        # Swapping the state_dict key-values instead of swapping model
-        # parameters because the state_dict is a shallow copy of model
-        # parameters.
-        self._swap_ema_state_dict(checkpoint)
+        if self._ema_started(runner):
+            checkpoint['ema_state_dict'] = self.ema_model.state_dict()
+            # Save ema parameters to the source model's state dict so that we
+            # can directly load the averaged model weights for deployment.
+            # Swapping the state_dict key-values instead of swapping model
+            # parameters because the state_dict is a shallow copy of model
+            # parameters.
+            self._swap_ema_state_dict(checkpoint)
 
     def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
         """Resume ema parameters from checkpoint."""
-
-        if 'ema_state_dict' in checkpoint:
-            # The original model parameters are actually saved in ema field.
-            # swap the weights back to resume ema state.
-            self._swap_ema_state_dict(checkpoint)
-            self.ema_model.load_state_dict(checkpoint['ema_state_dict'])
-
-        # Support load checkpoint without ema state dict.
-        else:
-            print_log(
-                'There is no `ema_state_dict` in checkpoint. '
-                '`EMAHook` will make a copy of `state_dict` as the '
-                'initial `ema_state_dict`', 'current', logging.WARNING)
-            self.ema_model.module.load_state_dict(
-                copy.deepcopy(checkpoint['state_dict']))
+        if self._ema_started(runner):
+            if 'ema_state_dict' in checkpoint:
+                # The original model parameters are actually saved in ema
+                # field. Swap the weights back to resume ema state.
+                self._swap_ema_state_dict(checkpoint)
+                self.ema_model.load_state_dict(
+                    checkpoint['ema_state_dict'], strict=self.strict_load)
+
+            # Support load checkpoint without ema state dict.
+            else:
+                print_log(
+                    'There is no `ema_state_dict` in checkpoint. '
+                    '`EMAHook` will make a copy of `state_dict` as the '
+                    'initial `ema_state_dict`', 'current', logging.WARNING)
+                self.ema_model.module.load_state_dict(
+                    copy.deepcopy(checkpoint['state_dict']),
+                    strict=self.strict_load)
 
     def _swap_ema_parameters(self) -> None:
         """Swap the parameter of model with ema_model."""
@@ -124,3 +168,21 @@ class EMAHook(Hook):
                 tmp = ema_state[k]
                 ema_state[k] = model_state[k[7:]]
                 model_state[k[7:]] = tmp
+
+    def _ema_started(self, runner) -> bool:
+        """Whether ``EMAHook`` has been initialized at the current iteration
+        or epoch.
+
+        :attr:`ema_model` will be initialized when ``runner.iter`` or
+        ``runner.epoch`` reaches ``self.begin_iter`` or ``self.begin_epoch``
+        for the first time.
+
+        Args:
+            runner (Runner): Runner of the training or validation process.
+
+        Returns:
+            bool: Whether ``EMAHook`` has been initialized.
+        """
+        if self.enabled_by_epoch:
+            return runner.epoch + 1 >= self.begin_epoch
+        else:
+            return runner.iter + 1 >= self.begin_iter
diff --git a/tests/test_hook/test_ema_hook.py b/tests/test_hook/test_ema_hook.py
index 39e8ae52..571125bc 100644
--- a/tests/test_hook/test_ema_hook.py
+++ b/tests/test_hook/test_ema_hook.py
@@ -43,6 +43,16 @@ class ToyModel1(BaseModel, ToyModel):
         return super(BaseModel, self).forward(*args, **kwargs)
 
 
+class ToyModel2(BaseModel, ToyModel):
+
+    def __init__(self):
+        super().__init__()
+        self.linear1 = nn.Linear(2, 1)
+
+    def forward(self, *args, **kwargs):
+        return super(BaseModel, self).forward(*args, **kwargs)
+
+
 @DATASETS.register_module()
 class DummyDataset(Dataset):
     METAINFO = dict()  # type: ignore
@@ -171,3 +181,79 @@ class TestEMAHook(TestCase):
             custom_hooks=[dict(type='EMAHook')],
             experiment_name='test4')
         runner.test()
+
+        # Test not loading the checkpoint strictly (strict_load=False).
+        # Test loading a checkpoint without ema_state_dict.
+        runner = Runner(
+            model=ToyModel2(),
+            test_dataloader=dict(
+                dataset=dict(type='DummyDataset'),
+                sampler=dict(type='DefaultSampler', shuffle=True),
+                batch_size=3,
+                num_workers=0),
+            test_evaluator=evaluator,
+            test_cfg=dict(),
+            work_dir=self.temp_dir.name,
+            load_from=osp.join(self.temp_dir.name, 'epoch_2.pth'),
+            default_hooks=dict(logger=None),
+            custom_hooks=[dict(type='EMAHook', strict_load=False)],
+            experiment_name='test5')
+        runner.test()
+
+        # Test enable ema at 5 epochs.
+        runner = Runner(
+            model=model,
+            train_dataloader=dict(
+                dataset=dict(type='DummyDataset'),
+                sampler=dict(type='DefaultSampler', shuffle=True),
+                batch_size=3,
+                num_workers=0),
+            val_dataloader=dict(
+                dataset=dict(type='DummyDataset'),
+                sampler=dict(type='DefaultSampler', shuffle=False),
+                batch_size=3,
+                num_workers=0),
+            val_evaluator=evaluator,
+            work_dir=self.temp_dir.name,
+            optim_wrapper=OptimWrapper(
+                torch.optim.Adam(ToyModel().parameters())),
+            train_cfg=dict(by_epoch=True, max_epochs=10, val_interval=1),
+            val_cfg=dict(),
+            default_hooks=dict(logger=None),
+            custom_hooks=[dict(type='EMAHook', begin_epoch=5)],
+            experiment_name='test6')
+        runner.train()
+        state_dict = torch.load(osp.join(self.temp_dir.name, 'epoch_4.pth'))
+        self.assertNotIn('ema_state_dict', state_dict)
+        state_dict = torch.load(osp.join(self.temp_dir.name, 'epoch_5.pth'))
+        self.assertIn('ema_state_dict', state_dict)
+
+        # Test enable ema at 5 iterations.
+        runner = Runner(
+            model=model,
+            train_dataloader=dict(
+                dataset=dict(type='DummyDataset'),
+                sampler=dict(type='DefaultSampler', shuffle=True),
+                batch_size=3,
+                num_workers=0),
+            val_dataloader=dict(
+                dataset=dict(type='DummyDataset'),
+                sampler=dict(type='DefaultSampler', shuffle=False),
+                batch_size=3,
+                num_workers=0),
+            val_evaluator=evaluator,
+            work_dir=self.temp_dir.name,
+            optim_wrapper=OptimWrapper(
+                torch.optim.Adam(ToyModel().parameters())),
+            train_cfg=dict(by_epoch=False, max_iters=10, val_interval=1),
+            val_cfg=dict(),
+            default_hooks=dict(
+                checkpoint=dict(
+                    type='CheckpointHook', interval=1, by_epoch=False)),
+            custom_hooks=[dict(type='EMAHook', begin_iter=5)],
+            experiment_name='test7')
+        runner.train()
+        state_dict = torch.load(osp.join(self.temp_dir.name, 'iter_4.pth'))
+        self.assertNotIn('ema_state_dict', state_dict)
+        state_dict = torch.load(osp.join(self.temp_dir.name, 'iter_5.pth'))
+        self.assertIn('ema_state_dict', state_dict)
--
GitLab
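
For context, a minimal configuration sketch of how the options introduced by this patch could be wired into a Runner config. Only the EMAHook settings come from the patch; the surrounding config entries (model, dataloaders, evaluator, work_dir) are placeholders the user supplies, and the numeric values are illustrative assumptions.

    # Usage sketch, assuming an otherwise complete mmengine Runner config.
    custom_hooks = [
        dict(
            type='EMAHook',
            ema_type='ExponentialMovingAverage',  # any strategy in mmengine.model.averaged_model
            strict_load=False,  # tolerate key mismatches when resuming from a checkpoint
            begin_epoch=5,      # start updating and saving EMA weights from the 5th epoch
            # begin_iter=2000,  # iteration-based alternative; mutually exclusive with begin_epoch
        )
    ]

With ``begin_epoch`` (or ``begin_iter``) set, checkpoints saved before the start point contain no ``ema_state_dict``, as exercised by the ``test6`` and ``test7`` cases above.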