Skip to content
Snippets Groups Projects
Unverified Commit 38e78d55 authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

[Fix] Fix ema hook and add unit test (#327)

* Fix ema hook and add unit test

* save state_dict of ema.module

save state_dict of ema.module

* replace warning.warn with MMLogger.warn

* fix as comment

* fix bug

* fix bug
parent 9c55b430
No related branches found
No related tags found
No related merge requests found
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import copy
import itertools import itertools
import logging
from typing import Dict, Optional from typing import Dict, Optional
from mmengine.logging import print_log
from mmengine.model import is_model_wrapper from mmengine.model import is_model_wrapper
from mmengine.registry import HOOKS, MODELS from mmengine.registry import HOOKS, MODELS
from .hook import DATA_BATCH, Hook from .hook import DATA_BATCH, Hook
...@@ -80,10 +83,21 @@ class EMAHook(Hook): ...@@ -80,10 +83,21 @@ class EMAHook(Hook):
def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
    """Restore EMA parameters when a checkpoint is loaded.

    If the checkpoint was saved by ``EMAHook``, the averaged weights live
    under ``ema_state_dict`` (and the raw model weights were swapped into
    the ema field before saving), so the swap is undone before loading.
    Checkpoints produced without this hook carry no ``ema_state_dict``;
    in that case the plain ``state_dict`` is deep-copied to seed the EMA
    model so averaging can start from the loaded weights.

    Args:
        runner: The runner driving the training/testing loop (unused here).
        checkpoint (dict): The loaded checkpoint dictionary.
    """
    has_ema_weights = 'ema_state_dict' in checkpoint
    if not has_ema_weights:
        # Support loading a checkpoint that was saved without this hook.
        print_log(
            'There is no `ema_state_dict` in checkpoint. '
            '`EMAHook` will make a copy of `state_dict` as the '
            'initial `ema_state_dict`', 'current', logging.WARNING)
        self.ema_model.module.load_state_dict(
            copy.deepcopy(checkpoint['state_dict']))
        return
    # The original model parameters are actually saved in the ema field.
    # Swap the weights back so the true EMA state is resumed.
    self._swap_ema_state_dict(checkpoint)
    self.ema_model.load_state_dict(checkpoint['ema_state_dict'])
def _swap_ema_parameters(self) -> None: def _swap_ema_parameters(self) -> None:
"""Swap the parameter of model with ema_model.""" """Swap the parameter of model with ema_model."""
......
...@@ -149,3 +149,25 @@ class TestEMAHook(TestCase): ...@@ -149,3 +149,25 @@ class TestEMAHook(TestCase):
custom_hooks=[dict(type='EMAHook')], custom_hooks=[dict(type='EMAHook')],
experiment_name='test3') experiment_name='test3')
runner.test() runner.test()
# Test load checkpoint without ema_state_dict
ckpt = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'))
ckpt.pop('ema_state_dict')
torch.save(ckpt,
osp.join(self.temp_dir.name, 'without_ema_state_dict.pth'))
runner = Runner(
model=DummyWrapper(ToyModel()),
test_dataloader=dict(
dataset=dict(type='DummyDataset'),
sampler=dict(type='DefaultSampler', shuffle=True),
batch_size=3,
num_workers=0),
test_evaluator=evaluator,
test_cfg=dict(),
work_dir=self.temp_dir.name,
load_from=osp.join(self.temp_dir.name,
'without_ema_state_dict.pth'),
default_hooks=dict(logger=None),
custom_hooks=[dict(type='EMAHook')],
experiment_name='test4')
runner.test()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment