Commit 6b47035f, authored by Mashiro and committed by GitHub

[Fix] Fix save scheduler state dict with optim wrapper (#375)

* fix save scheduler state dict with optim wrapper

* remove for loop and inherit TestParameterScheduler

* remove for loop and inherit TestParameterScheduler

* minor refine
parent 5b065b10
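For context: the change below touches the optimizer wrapper, the OneCycleParamScheduler annealing helpers, and the corresponding tests. The scenario being fixed is saving a parameter scheduler's state dict to disk when the scheduler is driven by an OptimWrapper instead of a raw optimizer. A minimal sketch of that round trip follows; the toy model, the ConstantParamScheduler arguments, and the checkpoint name are illustrative assumptions, not code from this commit.

import os.path as osp
import tempfile

import torch
import torch.nn as nn

from mmengine.optim import OptimWrapper
from mmengine.optim.scheduler import ConstantParamScheduler

# Toy setup (assumption): any model/optimizer pair works here.
model = nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optim_wrapper = OptimWrapper(optimizer=optimizer)

# Build the scheduler on the wrapper instead of the raw optimizer.
scheduler = ConstantParamScheduler(optim_wrapper, param_name='lr', factor=0.5)

# Save the scheduler state to disk and load it back, mirroring the
# torch.save/torch.load round trip added to the tests below.
with tempfile.TemporaryDirectory() as tmp_dir:
    ckpt = osp.join(tmp_dir, 'scheduler.pth')
    torch.save(scheduler.state_dict(), ckpt)
    state_dict = torch.load(ckpt)

restored = ConstantParamScheduler(optim_wrapper, param_name='lr', factor=0.5)
restored.load_state_dict(state_dict)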
 # Copyright (c) OpenMMLab. All rights reserved.
+import logging
 from contextlib import contextmanager
 from typing import Dict, List, Optional
@@ -7,7 +8,7 @@ import torch.nn as nn
 from torch.nn.utils import clip_grad
 from torch.optim import Optimizer
 
-from mmengine.logging import MessageHub, MMLogger
+from mmengine.logging import MessageHub, print_log
 from mmengine.registry import OPTIM_WRAPPERS
 from mmengine.utils import has_batch_norm
@@ -106,7 +107,6 @@ class OptimWrapper:
                 'If `clip_grad` is not None, it should be a `dict` '
                 'which is the arguments of `torch.nn.utils.clip_grad`')
         self.clip_grad_kwargs = clip_grad
-        self.logger = MMLogger.get_current_instance()
         # Used to update `grad_norm` log message.
         self.message_hub = MessageHub.get_current_instance()
         self._inner_count = 0
@@ -318,16 +318,20 @@ class OptimWrapper:
         self._inner_count = init_counts
         self._max_counts = max_counts
         if self._inner_count % self._accumulative_counts != 0:
-            self.logger.warning(
+            print_log(
                 'Resumed iteration number is not divisible by '
                 '`_accumulative_counts` in `GradientCumulativeOptimizerHook`, '
                 'which means the gradient of some iterations is lost and the '
-                'result may be influenced slightly.')
+                'result may be influenced slightly.',
+                logger='current',
+                level=logging.WARNING)
         if has_batch_norm(model) and self._accumulative_counts > 1:
-            self.logger.warning(
+            print_log(
                 'Gradient accumulative may slightly decrease '
-                'performance because the model has BatchNorm layers.')
+                'performance because the model has BatchNorm layers.',
+                logger='current',
+                level=logging.WARNING)
         # Remainder of `_max_counts` divided by `_accumulative_counts`
         self._remainder_counts = self._max_counts % self._accumulative_counts
...
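The hunks above drop the logger cached on the wrapper (`self.logger = MMLogger.get_current_instance()`) and switch the warnings to `print_log(..., logger='current', level=logging.WARNING)`, which looks the current logger up only when a message is emitted. Presumably this keeps live logger internals (stream handlers, locks) off the wrapper, so that objects which end up referencing the wrapper can still be serialized with torch.save. The snippet below is only a sketch of the before/after call pattern with a shortened message, not mmengine code.

import logging

from mmengine.logging import MMLogger, print_log

msg = 'Gradient accumulation may interact badly with BatchNorm layers.'

# Before: a logger instance fetched once and stored as an attribute.
logger = MMLogger.get_current_instance()
logger.warning(msg)

# After: resolve the current logger only at the moment of logging.
print_log(msg, logger='current', level=logging.WARNING)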
@@ -1025,13 +1025,15 @@ class OneCycleParamScheduler(_ParamScheduler):
         else:
             return [param] * len(optimizer.param_groups)
 
-    def _annealing_cos(self, start, end, pct):
+    @staticmethod
+    def _annealing_cos(start, end, pct):
         """Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."""
         cos_out = math.cos(math.pi * pct) + 1
         return end + (start - end) / 2.0 * cos_out
 
-    def _annealing_linear(self, start, end, pct):
+    @staticmethod
+    def _annealing_linear(start, end, pct):
         """Linearly anneal from `start` to `end` as pct goes from 0.0 to
         1.0."""
         return (end - start) * pct + start
...
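The likely motivation for making `_annealing_cos` and `_annealing_linear` static methods is serialization: the scheduler keeps its selected annealing function as an instance attribute, and instance attributes other than the optimizer end up in `state_dict()`. As a bound method that attribute carries a reference back to the scheduler (and, through it, the optimizer wrapper), so torch.save would try to pickle far more than scheduler state; a staticmethod resolves to a plain function that pickles by qualified name. A small self-contained illustration of the difference follows; the `Toy` class is a hypothetical stand-in, not mmengine code.

import math
import pickle
import threading


class Toy:
    """Stand-in for a scheduler that stores an annealing function."""

    def bound_cos(self, start, end, pct):
        return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1)

    @staticmethod
    def static_cos(start, end, pct):
        return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1)


toy = Toy()
toy.lock = threading.Lock()  # stand-in for unpicklable state on the instance

# A bound method pickles its instance along with it, so the whole object
# graph, including the lock, goes into the checkpoint and fails here.
try:
    pickle.dumps({'anneal_func': toy.bound_cos})
except TypeError as exc:
    print('bound method:', exc)

# The staticmethod is a plain function and pickles without the instance.
pickle.dumps({'anneal_func': Toy.static_cos})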
@@ -64,7 +64,6 @@ class TestOptimWrapper(MultiProcessTestCase):
         self.assertIs(optim_wrapper.optimizer, self.optimizer)
         self.assertIsNone(optim_wrapper.clip_grad_kwargs)
         self.assertEqual(optim_wrapper._accumulative_counts, 1)
-        self.assertIs(optim_wrapper.logger, self.logger)
         self.assertIs(optim_wrapper.message_hub, self.message_hub)
         self.assertEqual(optim_wrapper._inner_count, 0)
         self.assertEqual(optim_wrapper._max_counts, -1)
...
 # Copyright (c) OpenMMLab. All rights reserved.
 import math
+import os.path as osp
+import tempfile
 from unittest import TestCase
 
 import torch
 import torch.nn.functional as F
 import torch.optim as optim
 
+from mmengine.optim import OptimWrapper
 # yapf: disable
 from mmengine.optim.scheduler import (ConstantParamScheduler,
                                       CosineAnnealingParamScheduler,
@@ -55,6 +58,7 @@ class TestParameterScheduler(TestCase):
             lr=lr,
             momentum=momentum,
             weight_decay=weight_decay)
+        self.temp_dir = tempfile.TemporaryDirectory()
 
     def test_base_scheduler_step(self):
         with self.assertRaises(NotImplementedError):
@@ -408,7 +412,10 @@ class TestParameterScheduler(TestCase):
             scheduler.optimizer.step()
             scheduler.step()
         scheduler_copy = construct2()
-        scheduler_copy.load_state_dict(scheduler.state_dict())
+        torch.save(scheduler.state_dict(),
+                   osp.join(self.temp_dir.name, 'tmp.pth'))
+        state_dict = torch.load(osp.join(self.temp_dir.name, 'tmp.pth'))
+        scheduler_copy.load_state_dict(state_dict)
         for key in scheduler.__dict__.keys():
             if key != 'optimizer':
                 self.assertEqual(scheduler.__dict__[key],
@@ -743,3 +750,10 @@ class TestParameterScheduler(TestCase):
             param_name='lr',
             total_steps=10,
             anneal_strategy='a')
+
+
+class TestParameterSchedulerOptimWrapper(TestParameterScheduler):
+
+    def setUp(self):
+        super().setUp()
+        self.optimizer = OptimWrapper(optimizer=self.optimizer)
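The new TestParameterSchedulerOptimWrapper class relies on unittest's inheritance of test methods: every test_* method defined on TestParameterScheduler runs a second time against an optimizer wrapped in OptimWrapper, which is what replaces the per-test loop mentioned in the commit message. A generic sketch of the pattern, with hypothetical test class names, is:

import unittest

import torch
import torch.nn as nn

from mmengine.optim import OptimWrapper


class SchedulerTests(unittest.TestCase):
    """Hypothetical base suite; its tests only touch `self.optimizer`."""

    def setUp(self):
        model = nn.Linear(2, 2)
        self.optimizer = torch.optim.SGD(model.parameters(), lr=0.05)

    def test_param_groups_exposed(self):
        # Passes for a raw optimizer and for an OptimWrapper alike,
        # since OptimWrapper forwards `param_groups` to the optimizer.
        self.assertEqual(len(self.optimizer.param_groups), 1)


class SchedulerTestsOptimWrapper(SchedulerTests):
    """Re-runs every inherited test with the optimizer wrapped."""

    def setUp(self):
        super().setUp()
        self.optimizer = OptimWrapper(optimizer=self.optimizer)


if __name__ == '__main__':
    unittest.main()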