# Copyright (c) OpenMMLab. All rights reserved.
import copy
import itertools
import logging
from typing import Dict, Optional

from mmengine.logging import print_log
from mmengine.model import is_model_wrapper
from mmengine.registry import HOOKS, MODELS

from .hook import DATA_BATCH, Hook
@HOOKS.register_module()
class EMAHook(Hook):
    """A Hook to apply Exponential Moving Average (EMA) on the model during
    training.

    Note:
        - EMAHook takes priority over CheckpointHook.
        - The original model parameters are actually saved in ema field after
          train.

    Args:
        ema_type (str): The type of EMA strategy to use. You can find the
            supported strategies in ``mmengine.model.averaged_model``.
            Defaults to 'ExponentialMovingAverage'.
        **kwargs: Extra keyword arguments forwarded to the averaged-model
            class selected by ``ema_type`` (e.g. ``momentum``,
            ``update_buffers``).
    """

    priority = 'NORMAL'

    def __init__(self, ema_type: str = 'ExponentialMovingAverage', **kwargs):
        # Only record the config here; the EMA model cannot be built yet
        # because the source model is not available until ``before_run``.
        self.ema_cfg = dict(type=ema_type, **kwargs)

    def before_run(self, runner) -> None:
        """Create an ema copy of the model.

        Args:
            runner (Runner): The runner of the training process.
        """
        model = runner.model
        # Unwrap DDP-like wrappers so EMA tracks the bare model parameters.
        if is_model_wrapper(model):
            model = model.module
        self.src_model = model
        self.ema_model = MODELS.build(
            self.ema_cfg, default_args=dict(model=self.src_model))

    def after_train_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: DATA_BATCH = None,
                         outputs: Optional[dict] = None) -> None:
        """Update ema parameter.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (Sequence[dict], optional): Data from dataloader.
                Defaults to None.
            outputs (dict, optional): Outputs from model. Defaults to None.
        """
        self.ema_model.update_parameters(self.src_model)

    def before_val_epoch(self, runner) -> None:
        """We load parameter values from ema model to source model before
        validation.

        Args:
            runner (Runner): The runner of the training process.
        """
        self._swap_ema_parameters()

    def after_val_epoch(self,
                        runner,
                        metrics: Optional[Dict[str, float]] = None) -> None:
        """We recover source model's parameter from ema model after
        validation.

        Args:
            runner (Runner): The runner of the validation process.
            metrics (Dict[str, float], optional): Evaluation results of all
                metrics on validation dataset. Defaults to None.
        """
        self._swap_ema_parameters()

    def before_test_epoch(self, runner) -> None:
        """We load parameter values from ema model to source model before
        test.

        Args:
            runner (Runner): The runner of the training process.
        """
        self._swap_ema_parameters()

    def after_test_epoch(self,
                         runner,
                         metrics: Optional[Dict[str, float]] = None) -> None:
        """We recover source model's parameter from ema model after test.

        Args:
            runner (Runner): The runner of the testing process.
            metrics (Dict[str, float], optional): Evaluation results of all
                metrics on test dataset. Defaults to None.
        """
        self._swap_ema_parameters()

    def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
        """Save ema parameters to checkpoint.

        Args:
            runner (Runner): The runner of the training process.
            checkpoint (dict): The checkpoint about to be saved.
        """
        checkpoint['ema_state_dict'] = self.ema_model.state_dict()
        # Save ema parameters to the source model's state dict so that we can
        # directly load the averaged model weights for deployment.
        # Swapping the state_dict key-values instead of swapping model
        # parameters because the state_dict is a shallow copy of model
        # parameters.
        self._swap_ema_state_dict(checkpoint)

    def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
        """Resume ema parameters from checkpoint.

        Args:
            runner (Runner): The runner of the training process.
            checkpoint (dict): The loaded checkpoint.
        """
        if 'ema_state_dict' in checkpoint:
            # The original model parameters are actually saved in ema field.
            # swap the weights back to resume ema state.
            self._swap_ema_state_dict(checkpoint)
            self.ema_model.load_state_dict(checkpoint['ema_state_dict'])
        # Support load checkpoint without ema state dict.
        else:
            print_log(
                'There is no `ema_state_dict` in checkpoint. '
                '`EMAHook` will make a copy of `state_dict` as the '
                'initial `ema_state_dict`', 'current', logging.WARNING)
            # Deep-copy so later in-place EMA updates cannot mutate the
            # tensors still referenced by ``checkpoint['state_dict']``.
            self.ema_model.module.load_state_dict(
                copy.deepcopy(checkpoint['state_dict']))

    def _swap_ema_parameters(self) -> None:
        """Swap the parameter of model with ema_model.

        Buffers are included in the swap only when the averaged model was
        built with ``update_buffers=True``; otherwise they are shared.
        """
        avg_param = (
            itertools.chain(self.ema_model.module.parameters(),
                            self.ema_model.module.buffers())
            if self.ema_model.update_buffers else
            self.ema_model.module.parameters())
        src_param = (
            itertools.chain(self.src_model.parameters(),
                            self.src_model.buffers())
            if self.ema_model.update_buffers else self.src_model.parameters())
        # In-place three-way swap keeps both models' storage identities.
        for p_avg, p_src in zip(avg_param, src_param):
            tmp = p_avg.data.clone()
            p_avg.data.copy_(p_src.data)
            p_src.data.copy_(tmp)

    def _swap_ema_state_dict(self, checkpoint) -> None:
        """Swap the state dict values of model with ema_model.

        Args:
            checkpoint (dict): Checkpoint containing both ``state_dict`` and
                ``ema_state_dict`` entries; swapped in place.
        """
        model_state = checkpoint['state_dict']
        ema_state = checkpoint['ema_state_dict']
        for k in ema_state:
            # EMA keys are prefixed with 'module.'; non-parameter entries
            # (e.g. the averaging step counter) carry no prefix and are
            # intentionally left untouched.
            if k.startswith('module.'):
                tmp = ema_state[k]
                ema_state[k] = model_state[k[7:]]
                model_state[k[7:]] = tmp