From ab31e1936ea4d4905856cb21919cb6d65ecf3781 Mon Sep 17 00:00:00 2001 From: Yuan Liu <30762564+YuanLiuuuuuu@users.noreply.github.com> Date: Tue, 1 Mar 2022 15:49:44 +0800 Subject: [PATCH] [Feature]: Add param scheduler hook (#63) * [Feature]: Add param scheduler hook * [Fix]: update docstring and add assert_call to UT --- mmengine/hooks/__init__.py | 5 ++- mmengine/hooks/param_scheduler_hook.py | 41 ++++++++++++++++++++ tests/test_hook/test_param_scheduler_hook.py | 27 +++++++++++++ 3 files changed, 72 insertions(+), 1 deletion(-) create mode 100644 mmengine/hooks/param_scheduler_hook.py create mode 100644 tests/test_hook/test_param_scheduler_hook.py diff --git a/mmengine/hooks/__init__.py b/mmengine/hooks/__init__.py index c6895bea..ff87dc67 100644 --- a/mmengine/hooks/__init__.py +++ b/mmengine/hooks/__init__.py @@ -2,5 +2,8 @@ from .hook import Hook from .iter_timer_hook import IterTimerHook from .sampler_seed_hook import DistSamplerSeedHook +from .param_scheduler_hook import ParamSchedulerHook -__all__ = ['Hook', 'IterTimerHook', 'DistSamplerSeedHook'] +__all__ = [ + 'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook' +] diff --git a/mmengine/hooks/param_scheduler_hook.py b/mmengine/hooks/param_scheduler_hook.py new file mode 100644 index 00000000..1bbf610f --- /dev/null +++ b/mmengine/hooks/param_scheduler_hook.py @@ -0,0 +1,41 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Sequence + +from mmengine.data import BaseDataSample +from mmengine.registry import HOOKS +from .hook import Hook + + +@HOOKS.register_module() +class ParamSchedulerHook(Hook): + """A hook to update some hyper-parameters in optimizer, e.g learning rate + and momentum.""" + + def after_iter(self, + runner: object, + data_batch: Optional[Sequence[BaseDataSample]] = None, + outputs: Optional[Sequence[BaseDataSample]] = None) -> None: + """Call step function for each scheduler after each iteration. + + Args: + runner (object): The runner of the training process. + data_batch (Sequence[BaseDataSample]): Data from dataloader. In + order to keep this interface consistent with other hooks, we + keep ``data_batch`` here. Defaults to None. + outputs (Sequence[BaseDataSample]): Outputs from model. In + order to keep this interface consistent with other hooks, we + keep ``data_batch`` here. Defaults to None. + """ + for scheduler in runner.schedulers: # type: ignore + if not scheduler.by_epoch: + scheduler.step() + + def after_epoch(self, runner: object) -> None: + """Call step function for each scheduler after each epoch. + + Args: + runner (object): The runner of the training process. + """ + for scheduler in runner.schedulers: # type: ignore + if scheduler.by_epoch: + scheduler.step() diff --git a/tests/test_hook/test_param_scheduler_hook.py b/tests/test_hook/test_param_scheduler_hook.py new file mode 100644 index 00000000..75f12c4a --- /dev/null +++ b/tests/test_hook/test_param_scheduler_hook.py @@ -0,0 +1,27 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest.mock import Mock + +from mmengine.hooks import ParamSchedulerHook + + +class TestParamSchedulerHook: + + def test_after_iter(self): + Hook = ParamSchedulerHook() + Runner = Mock() + scheduler = Mock() + scheduler.step = Mock() + scheduler.by_epoch = False + Runner.schedulers = [scheduler] + Hook.after_iter(Runner) + scheduler.step.assert_called() + + def test_after_epoch(self): + Hook = ParamSchedulerHook() + Runner = Mock() + scheduler = Mock() + scheduler.step = Mock() + scheduler.by_epoch = True + Runner.schedulers = [scheduler] + Hook.after_epoch(Runner) + scheduler.step.assert_called() -- GitLab