Unverified commit a6f52977, authored by takuoko, committed by GitHub

[fix] EMAHook load state dict (#507)

* fix ema load_state_dict

* fix ema load_state_dict

* fix for test

* fix by review

* fix resume and keys
parent cfb884c1
@@ -7,6 +7,7 @@ from typing import Dict, Optional
 from mmengine.logging import print_log
 from mmengine.model import is_model_wrapper
 from mmengine.registry import HOOKS, MODELS
+from mmengine.runner.checkpoint import _load_checkpoint_to_model
 from .hook import DATA_BATCH, Hook

@@ -171,7 +172,7 @@ class EMAHook(Hook):
         Args:
             runner (Runner): The runner of the testing process.
         """
-        if 'ema_state_dict' in checkpoint:
+        if 'ema_state_dict' in checkpoint and runner._resume:
             # The original model parameters are actually saved in ema
             # field swap the weights back to resume ema state.
             self._swap_ema_state_dict(checkpoint)

@@ -180,11 +181,13 @@ class EMAHook(Hook):
         # Support load checkpoint without ema state dict.
         else:
-            print_log(
-                'There is no `ema_state_dict` in checkpoint. '
-                '`EMAHook` will make a copy of `state_dict` as the '
-                'initial `ema_state_dict`', 'current', logging.WARNING)
-            self.ema_model.module.load_state_dict(
+            if runner._resume:
+                print_log(
+                    'There is no `ema_state_dict` in checkpoint. '
+                    '`EMAHook` will make a copy of `state_dict` as the '
+                    'initial `ema_state_dict`', 'current', logging.WARNING)
+            _load_checkpoint_to_model(
+                self.ema_model.module,
                 copy.deepcopy(checkpoint['state_dict']),
                 strict=self.strict_load)
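For context, here is a minimal sketch in plain PyTorch of what the strict flag governs when one model's state dict is copied into another. The src/dst modules below are hypothetical and not part of this change:

import torch.nn as nn

# "Checkpoint" model and a destination model with one extra layer.
src = nn.Sequential(nn.Linear(2, 2))
dst = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 1))

ckpt = src.state_dict()  # keys: '0.weight', '0.bias'

# strict=True would raise because '1.weight'/'1.bias' are missing from
# ckpt; strict=False copies the matching keys and reports the rest.
result = dst.load_state_dict(ckpt, strict=False)
print(result.missing_keys)     # ['1.weight', '1.bias']
print(result.unexpected_keys)  # []

A parameter whose shape differs (the "different size head" case exercised in the test below) still raises in plain nn.Module.load_state_dict even with strict=False, which appears to be why the hook now routes the copy through _load_checkpoint_to_model: MMEngine's checkpoint loader reports such mismatches as warnings rather than errors when strict is False.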
@@ -56,6 +56,16 @@ class ToyModel2(BaseModel, ToyModel):
         return super(BaseModel, self).forward(*args, **kwargs)

+class ToyModel3(BaseModel, ToyModel):
+
+    def __init__(self):
+        super().__init__()
+        self.linear1 = nn.Linear(2, 2)
+
+    def forward(self, *args, **kwargs):
+        return super(BaseModel, self).forward(*args, **kwargs)
+
 @DATASETS.register_module()
 class DummyDataset(Dataset):
     METAINFO = dict()  # type: ignore
@@ -203,6 +213,25 @@ class TestEMAHook(TestCase):
             experiment_name='test5')
         runner.test()

+        # Test non-strict loading (strict_load=False): load a checkpoint
+        # without an ema_state_dict into a model with a different-sized
+        # head.
+        runner = Runner(
+            model=ToyModel3(),
+            test_dataloader=dict(
+                dataset=dict(type='DummyDataset'),
+                sampler=dict(type='DefaultSampler', shuffle=True),
+                batch_size=3,
+                num_workers=0),
+            test_evaluator=evaluator,
+            test_cfg=dict(),
+            work_dir=self.temp_dir.name,
+            load_from=osp.join(self.temp_dir.name, 'epoch_2.pth'),
+            default_hooks=dict(logger=None),
+            custom_hooks=[dict(type='EMAHook', strict_load=False)],
+            experiment_name='test5')
+        runner.test()
+
         # Test enable ema at 5 epochs.
         runner = Runner(
             model=model,
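As a usage note, a hedged sketch of the two paths the fixed hook distinguishes, written as config fragments (the checkpoint path is a placeholder):

# Resuming training: the checkpoint's ema_state_dict is swapped back in,
# so the EMA statistics continue from where they stopped.
resume_cfg = dict(
    resume=True,
    custom_hooks=[dict(type='EMAHook')],
)

# Fine-tuning or testing from a plain checkpoint, possibly with a
# different-sized head: the ordinary state_dict seeds the EMA model and
# strict_load=False tolerates missing or mismatched keys.
finetune_cfg = dict(
    load_from='work_dirs/baseline/epoch_2.pth',  # placeholder path
    custom_hooks=[dict(type='EMAHook', strict_load=False)],
)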