diff --git a/mmengine/hooks/__init__.py b/mmengine/hooks/__init__.py index c6895bea72081daaff9c7199eaca3410c2d3bef3..ff87dc6765c1bfe8038a373332c2a34c5fc12867 100644 --- a/mmengine/hooks/__init__.py +++ b/mmengine/hooks/__init__.py @@ -2,5 +2,8 @@ from .hook import Hook from .iter_timer_hook import IterTimerHook from .sampler_seed_hook import DistSamplerSeedHook +from .param_scheduler_hook import ParamSchedulerHook -__all__ = ['Hook', 'IterTimerHook', 'DistSamplerSeedHook'] +__all__ = [ + 'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook' +] diff --git a/mmengine/hooks/param_scheduler_hook.py b/mmengine/hooks/param_scheduler_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..1bbf610ff0405e3df8b07b0773cc468ed0fb1e3f --- /dev/null +++ b/mmengine/hooks/param_scheduler_hook.py @@ -0,0 +1,41 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence + +from mmengine.data import BaseDataSample +from mmengine.registry import HOOKS +from .hook import Hook + + +@HOOKS.register_module() +class ParamSchedulerHook(Hook): + """A hook to update some hyper-parameters in optimizer, e.g learning rate + and momentum.""" + + def after_iter(self, + runner: object, + data_batch: Optional[Sequence[BaseDataSample]] = None, + outputs: Optional[Sequence[BaseDataSample]] = None) -> None: + """Call step function for each scheduler after each iteration. + + Args: + runner (object): The runner of the training process. + data_batch (Sequence[BaseDataSample]): Data from dataloader. In + order to keep this interface consistent with other hooks, we + keep ``data_batch`` here. Defaults to None. + outputs (Sequence[BaseDataSample]): Outputs from model. In + order to keep this interface consistent with other hooks, we + keep ``data_batch`` here. Defaults to None. + """ + for scheduler in runner.schedulers: # type: ignore + if not scheduler.by_epoch: + scheduler.step() + + def after_epoch(self, runner: object) -> None: + """Call step function for each scheduler after each epoch. + + Args: + runner (object): The runner of the training process. + """ + for scheduler in runner.schedulers: # type: ignore + if scheduler.by_epoch: + scheduler.step() diff --git a/tests/test_hook/test_param_scheduler_hook.py b/tests/test_hook/test_param_scheduler_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..75f12c4a8a5c732f54def0a60a7981d1094bb67b --- /dev/null +++ b/tests/test_hook/test_param_scheduler_hook.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest.mock import Mock + +from mmengine.hooks import ParamSchedulerHook + + +class TestParamSchedulerHook: + + def test_after_iter(self): + Hook = ParamSchedulerHook() + Runner = Mock() + scheduler = Mock() + scheduler.step = Mock() + scheduler.by_epoch = False + Runner.schedulers = [scheduler] + Hook.after_iter(Runner) + scheduler.step.assert_called() + + def test_after_epoch(self): + Hook = ParamSchedulerHook() + Runner = Mock() + scheduler = Mock() + scheduler.step = Mock() + scheduler.by_epoch = True + Runner.schedulers = [scheduler] + Hook.after_epoch(Runner) + scheduler.step.assert_called()