Skip to content
Snippets Groups Projects
Unverified Commit ab31e193 authored by Yuan Liu's avatar Yuan Liu Committed by GitHub
Browse files

[Feature]: Add param scheduler hook (#63)

* [Feature]: Add param scheduler hook

* [Fix]: update docstring and add assert_call to UT
parent 2d3e9124
No related branches found
No related tags found
No related merge requests found
......@@ -2,5 +2,8 @@
from .hook import Hook
from .iter_timer_hook import IterTimerHook
from .sampler_seed_hook import DistSamplerSeedHook
from .param_scheduler_hook import ParamSchedulerHook
__all__ = ['Hook', 'IterTimerHook', 'DistSamplerSeedHook']
__all__ = [
'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook'
]
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence
from mmengine.data import BaseDataSample
from mmengine.registry import HOOKS
from .hook import Hook
@HOOKS.register_module()
class ParamSchedulerHook(Hook):
"""A hook to update some hyper-parameters in optimizer, e.g learning rate
and momentum."""
def after_iter(self,
runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""Call step function for each scheduler after each iteration.
Args:
runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample]): Data from dataloader. In
order to keep this interface consistent with other hooks, we
keep ``data_batch`` here. Defaults to None.
outputs (Sequence[BaseDataSample]): Outputs from model. In
order to keep this interface consistent with other hooks, we
keep ``data_batch`` here. Defaults to None.
"""
for scheduler in runner.schedulers: # type: ignore
if not scheduler.by_epoch:
scheduler.step()
def after_epoch(self, runner: object) -> None:
"""Call step function for each scheduler after each epoch.
Args:
runner (object): The runner of the training process.
"""
for scheduler in runner.schedulers: # type: ignore
if scheduler.by_epoch:
scheduler.step()
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import Mock
from mmengine.hooks import ParamSchedulerHook
class TestParamSchedulerHook:
def test_after_iter(self):
Hook = ParamSchedulerHook()
Runner = Mock()
scheduler = Mock()
scheduler.step = Mock()
scheduler.by_epoch = False
Runner.schedulers = [scheduler]
Hook.after_iter(Runner)
scheduler.step.assert_called()
def test_after_epoch(self):
Hook = ParamSchedulerHook()
Runner = Mock()
scheduler = Mock()
scheduler.step = Mock()
scheduler.by_epoch = True
Runner.schedulers = [scheduler]
Hook.after_epoch(Runner)
scheduler.step.assert_called()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment