# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import Mock

import pytest

from mmengine.hooks import ParamSchedulerHook


class TestParamSchedulerHook:
    """Unit tests for :class:`ParamSchedulerHook`.

    These tests check that iteration-based schedulers (``by_epoch=False``)
    are stepped by ``after_train_iter`` and epoch-based schedulers
    (``by_epoch=True``) by ``after_train_epoch``, for both supported
    layouts of ``runner.param_schedulers`` (a list of schedulers, or a dict
    of such lists), and that any other layout raises ``TypeError``.
    """

    error_msg = ('runner.param_schedulers should be list of ParamScheduler or '
                 'a dict containing list of ParamScheduler')

    @staticmethod
    def _make_scheduler(by_epoch):
        """Return a mock scheduler whose ``step`` calls can be asserted."""
        scheduler = Mock()
        scheduler.step = Mock()
        scheduler.by_epoch = by_epoch
        return scheduler

    def test_after_iter(self):
        hook = ParamSchedulerHook()

        # runner.param_schedulers should be a list or dict; a bare scheduler
        # object is rejected.
        runner = Mock()
        scheduler = self._make_scheduler(by_epoch=False)
        runner.param_schedulers = scheduler
        with pytest.raises(TypeError, match=self.error_msg):
            hook.after_train_iter(runner, 0)
        # Checked outside the ``with`` block: anything after the raising
        # call inside it would never execute. The type check must fire
        # before any scheduler is stepped.
        scheduler.step.assert_not_called()

        # runner.param_schedulers is a list of schedulers
        runner = Mock()
        scheduler = self._make_scheduler(by_epoch=False)
        runner.param_schedulers = [scheduler]
        hook.after_train_iter(runner, 0)
        scheduler.step.assert_called()

        # runner.param_schedulers is a dict containing list of schedulers
        scheduler1 = self._make_scheduler(by_epoch=False)
        scheduler2 = self._make_scheduler(by_epoch=False)
        runner.param_schedulers = dict(key1=[scheduler1], key2=[scheduler2])
        # Iteration-based schedulers must not be stepped at epoch end.
        hook.after_train_epoch(runner)
        scheduler1.step.assert_not_called()
        scheduler2.step.assert_not_called()
        # ...but every list in the dict is stepped after an iteration.
        hook.after_train_iter(runner, 0)
        scheduler1.step.assert_called()
        scheduler2.step.assert_called()

    def test_after_epoch(self):
        hook = ParamSchedulerHook()

        # runner.param_schedulers should be a list or dict; a bare scheduler
        # object is rejected by the *epoch* hook as well.
        runner = Mock()
        scheduler = self._make_scheduler(by_epoch=True)
        runner.param_schedulers = scheduler
        with pytest.raises(TypeError, match=self.error_msg):
            hook.after_train_epoch(runner)
        # See test_after_iter: asserts after the raising call inside the
        # ``with`` block are unreachable, so check here instead.
        scheduler.step.assert_not_called()

        # runner.param_schedulers is a list of schedulers
        runner = Mock()
        scheduler = self._make_scheduler(by_epoch=True)
        runner.param_schedulers = [scheduler]
        hook.after_train_epoch(runner)
        scheduler.step.assert_called()

        # runner.param_schedulers is a dict containing list of schedulers
        scheduler1 = self._make_scheduler(by_epoch=True)
        scheduler2 = self._make_scheduler(by_epoch=True)
        runner.param_schedulers = dict(key1=[scheduler1], key2=[scheduler2])
        # Epoch-based schedulers must not be stepped after an iteration.
        hook.after_train_iter(runner, 0)
        scheduler1.step.assert_not_called()
        scheduler2.step.assert_not_called()
        # ...but every list in the dict is stepped at epoch end.
        hook.after_train_epoch(runner)
        scheduler1.step.assert_called()
        scheduler2.step.assert_called()