From a07a063306f24fb0f4d59af19da4b33326cf3286 Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Mon, 8 Aug 2022 20:34:16 +0800 Subject: [PATCH] [Enhance] Add build function for scheduler. (#372) * add build function for scheduler * add unit test add unit test * handle convert_to_iter in build_scheduler_from_cfg * restore deleted code * format import * fix lint --- mmengine/registry/__init__.py | 5 +- mmengine/registry/build_functions.py | 82 +++++++- mmengine/registry/root.py | 6 +- mmengine/runner/runner.py | 28 +-- tests/test_registry/test_build_functions.py | 213 ++++++++++++++++++++ tests/test_registry/test_registry.py | 184 +---------------- tests/test_runner/test_runner.py | 5 - 7 files changed, 305 insertions(+), 218 deletions(-) create mode 100644 tests/test_registry/test_build_functions.py diff --git a/mmengine/registry/__init__.py b/mmengine/registry/__init__.py index 8be89e69..1e1de6bf 100644 --- a/mmengine/registry/__init__.py +++ b/mmengine/registry/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .build_functions import (build_from_cfg, build_model_from_cfg, - build_runner_from_cfg) + build_runner_from_cfg, build_scheduler_from_cfg) from .default_scope import DefaultScope from .registry import Registry from .root import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS, LOG_PROCESSORS, @@ -17,5 +17,6 @@ __all__ = [ 'PARAM_SCHEDULERS', 'METRICS', 'MODEL_WRAPPERS', 'OPTIM_WRAPPERS', 'LOOPS', 'VISBACKENDS', 'VISUALIZERS', 'LOG_PROCESSORS', 'EVALUATOR', 'DefaultScope', 'traverse_registry_tree', 'count_registered_modules', - 'build_model_from_cfg', 'build_runner_from_cfg', 'build_from_cfg' + 'build_model_from_cfg', 'build_runner_from_cfg', 'build_from_cfg', + 'build_scheduler_from_cfg' ] diff --git a/mmengine/registry/build_functions.py b/mmengine/registry/build_functions.py index 0f4dbb96..78348086 100644 --- a/mmengine/registry/build_functions.py +++ b/mmengine/registry/build_functions.py @@ -1,12 +1,18 @@ # Copyright (c) OpenMMLab. All rights reserved. import inspect import logging -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union + +import torch.nn as nn from ..config import Config, ConfigDict from ..utils import ManagerMixin from .registry import Registry +if TYPE_CHECKING: + from ..optim.scheduler import _ParamScheduler + from ..runner import Runner + def build_from_cfg( cfg: Union[dict, ConfigDict, Config], @@ -131,7 +137,7 @@ def build_from_cfg( def build_runner_from_cfg(cfg: Union[dict, ConfigDict, Config], - registry: Registry) -> Any: + registry: Registry) -> 'Runner': """Build a Runner object. Examples: >>> from mmengine.registry import Registry, build_runner_from_cfg @@ -203,7 +209,11 @@ def build_runner_from_cfg(cfg: Union[dict, ConfigDict, Config], f'{cls_location}.py: {e}') -def build_model_from_cfg(cfg, registry, default_args=None): +def build_model_from_cfg( + cfg: Union[dict, ConfigDict, Config], + registry: Registry, + default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> \ + nn.Module: """Build a PyTorch model from config dict(s). Different from ``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built. @@ -226,3 +236,69 @@ def build_model_from_cfg(cfg, registry, default_args=None): return Sequential(*modules) else: return build_from_cfg(cfg, registry, default_args) + + +def build_scheduler_from_cfg( + cfg: Union[dict, ConfigDict, Config], + registry: Registry, + default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> \ + '_ParamScheduler': + """Builds a ``ParamScheduler`` instance from config. + + ``ParamScheduler`` supports building instance by its constructor or + method ``build_iter_from_epoch``. Therefore, its registry needs a build + function to handle both cases. + + Args: + cfg (dict or ConfigDict or Config): Config dictionary. If it contains + the key ``convert_to_iter_based``, instance will be built by method + ``convert_to_iter_based``, otherwise instance will be built by its + constructor. + registry (:obj:`Registry`): The ``PARAM_SCHEDULERS`` registry. + default_args (dict or ConfigDict or Config, optional): Default + initialization arguments. It must contain key ``optimizer``. If + ``convert_to_iter_based`` is defined in ``cfg``, it must + additionally contain key ``epoch_length``. Defaults to None. + + Returns: + object: The constructed ``ParamScheduler``. + """ + assert isinstance( + cfg, + (dict, ConfigDict, Config + )), f'cfg should be a dict, ConfigDict or Config, but got {type(cfg)}' + assert isinstance( + registry, Registry), ('registry should be a mmengine.Registry object', + f'but got {type(registry)}') + + args = cfg.copy() + if default_args is not None: + for name, value in default_args.items(): + args.setdefault(name, value) + scope = args.pop('_scope_', None) + with registry.switch_scope_and_registry(scope) as registry: + convert_to_iter = args.pop('convert_to_iter_based', False) + if convert_to_iter: + scheduler_type = args.pop('type') + assert 'epoch_length' in args and args.get('by_epoch', True), ( + 'Only epoch-based parameter scheduler can be converted to ' + 'iter-based, and `epoch_length` should be set') + if isinstance(scheduler_type, str): + scheduler_cls = registry.get(scheduler_type) + if scheduler_cls is None: + raise KeyError( + f'{scheduler_type} is not in the {registry.name} ' + 'registry. Please check whether the value of ' + f'`{scheduler_type}` is correct or it was registered ' + 'as expected. More details can be found at https://mmengine.readthedocs.io/en/latest/tutorials/config.html#import-custom-python-modules' # noqa: E501 + ) + elif inspect.isclass(scheduler_type): + scheduler_cls = scheduler_type + else: + raise TypeError('type must be a str or valid type, but got ' + f'{type(scheduler_type)}') + return scheduler_cls.build_iter_from_epoch( # type: ignore + **args) + else: + args.pop('epoch_length', None) + return build_from_cfg(args, registry) diff --git a/mmengine/registry/root.py b/mmengine/registry/root.py index a63686f9..838465b4 100644 --- a/mmengine/registry/root.py +++ b/mmengine/registry/root.py @@ -6,7 +6,8 @@ More datails can be found at https://mmengine.readthedocs.io/en/latest/tutorials/registry.html. """ -from mmengine.registry import build_model_from_cfg, build_runner_from_cfg +from .build_functions import (build_model_from_cfg, build_runner_from_cfg, + build_scheduler_from_cfg) from .registry import Registry # manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner` @@ -37,7 +38,8 @@ OPTIM_WRAPPERS = Registry('optim_wrapper') # manage constructors that customize the optimization hyperparameters. OPTIM_WRAPPER_CONSTRUCTORS = Registry('optimizer wrapper constructor') # mangage all kinds of parameter schedulers like `MultiStepLR` -PARAM_SCHEDULERS = Registry('parameter scheduler') +PARAM_SCHEDULERS = Registry( + 'parameter scheduler', build_func=build_scheduler_from_cfg) # manage all kinds of metrics METRICS = Registry('metric') diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index 3e416842..b7b2f8ef 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -1134,34 +1134,16 @@ class Runner: f'The `end` of {_scheduler["type"]} is not set. ' 'Use the max epochs/iters of train loop as default.') - convert_to_iter = _scheduler.pop('convert_to_iter_based', - False) - if convert_to_iter: - assert _scheduler.get( - 'by_epoch', - True), ('only epoch-based parameter scheduler can be ' - 'converted to iter-based') - assert isinstance(self._train_loop, BaseLoop), \ - 'Scheduler can only be converted to iter-based ' \ - 'when train loop is built.' - cls = PARAM_SCHEDULERS.get(_scheduler.pop('type')) - param_schedulers.append( - cls.build_iter_from_epoch( # type: ignore + param_schedulers.append( + PARAM_SCHEDULERS.build( + _scheduler, + default_args=dict( optimizer=optim_wrapper, - **_scheduler, - epoch_length=len( - self.train_dataloader), # type: ignore - )) - else: - param_schedulers.append( - PARAM_SCHEDULERS.build( - _scheduler, - default_args=dict(optimizer=optim_wrapper))) + epoch_length=len(self.train_dataloader)))) else: raise TypeError( 'scheduler should be a _ParamScheduler object or dict, ' f'but got {scheduler}') - return param_schedulers def build_param_scheduler( diff --git a/tests/test_registry/test_build_functions.py b/tests/test_registry/test_build_functions.py new file mode 100644 index 00000000..618f99e3 --- /dev/null +++ b/tests/test_registry/test_build_functions.py @@ -0,0 +1,213 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pytest +import torch.nn as nn +from torch.optim import SGD + +from mmengine import (PARAM_SCHEDULERS, Config, ConfigDict, ManagerMixin, + Registry, build_from_cfg, build_model_from_cfg) + + +@pytest.mark.parametrize('cfg_type', [dict, ConfigDict, Config]) +def test_build_from_cfg(cfg_type): + BACKBONES = Registry('backbone') + + @BACKBONES.register_module() + class ResNet: + + def __init__(self, depth, stages=4): + self.depth = depth + self.stages = stages + + @BACKBONES.register_module() + class ResNeXt: + + def __init__(self, depth, stages=4): + self.depth = depth + self.stages = stages + + # test `cfg` parameter + # `cfg` should be a dict, ConfigDict or Config object + with pytest.raises( + TypeError, + match=('cfg should be a dict, ConfigDict or Config, but got ' + "<class 'str'>")): + cfg = 'ResNet' + model = build_from_cfg(cfg, BACKBONES) + + # `cfg` is a dict, ConfigDict or Config object + cfg = cfg_type(dict(type='ResNet', depth=50)) + model = build_from_cfg(cfg, BACKBONES) + assert isinstance(model, ResNet) + assert model.depth == 50 and model.stages == 4 + + # `cfg` is a dict but it does not contain the key "type" + with pytest.raises(KeyError, match='must contain the key "type"'): + cfg = dict(depth=50, stages=4) + cfg = cfg_type(cfg) + model = build_from_cfg(cfg, BACKBONES) + + # cfg['type'] should be a str or class + with pytest.raises( + TypeError, + match="type must be a str or valid type, but got <class 'int'>"): + cfg = dict(type=1000) + cfg = cfg_type(cfg) + model = build_from_cfg(cfg, BACKBONES) + + cfg = cfg_type(dict(type='ResNeXt', depth=50, stages=3)) + model = build_from_cfg(cfg, BACKBONES) + assert isinstance(model, ResNeXt) + assert model.depth == 50 and model.stages == 3 + + cfg = cfg_type(dict(type=ResNet, depth=50)) + model = build_from_cfg(cfg, BACKBONES) + assert isinstance(model, ResNet) + assert model.depth == 50 and model.stages == 4 + + # non-registered class + with pytest.raises(KeyError, match='VGG is not in the backbone registry'): + cfg = cfg_type(dict(type='VGG')) + model = build_from_cfg(cfg, BACKBONES) + + # `cfg` contains unexpected arguments + with pytest.raises(TypeError): + cfg = cfg_type(dict(type='ResNet', non_existing_arg=50)) + model = build_from_cfg(cfg, BACKBONES) + + # test `default_args` parameter + cfg = cfg_type(dict(type='ResNet', depth=50)) + model = build_from_cfg(cfg, BACKBONES, cfg_type(dict(stages=3))) + assert isinstance(model, ResNet) + assert model.depth == 50 and model.stages == 3 + + # default_args must be a dict or None + with pytest.raises(TypeError): + cfg = cfg_type(dict(type='ResNet', depth=50)) + model = build_from_cfg(cfg, BACKBONES, default_args=1) + + # cfg or default_args should contain the key "type" + with pytest.raises(KeyError, match='must contain the key "type"'): + cfg = cfg_type(dict(depth=50)) + model = build_from_cfg( + cfg, BACKBONES, default_args=cfg_type(dict(stages=4))) + + # "type" defined using default_args + cfg = cfg_type(dict(depth=50)) + model = build_from_cfg( + cfg, BACKBONES, default_args=cfg_type(dict(type='ResNet'))) + assert isinstance(model, ResNet) + assert model.depth == 50 and model.stages == 4 + + cfg = cfg_type(dict(depth=50)) + model = build_from_cfg( + cfg, BACKBONES, default_args=cfg_type(dict(type=ResNet))) + assert isinstance(model, ResNet) + assert model.depth == 50 and model.stages == 4 + + # test `registry` parameter + # incorrect registry type + with pytest.raises( + TypeError, + match=('registry must be a mmengine.Registry object, but got ' + "<class 'str'>")): + cfg = cfg_type(dict(type='ResNet', depth=50)) + model = build_from_cfg(cfg, 'BACKBONES') + + VISUALIZER = Registry('visualizer') + + @VISUALIZER.register_module() + class Visualizer(ManagerMixin): + + def __init__(self, name): + super().__init__(name) + + with pytest.raises(RuntimeError): + Visualizer.get_current_instance() + cfg = dict(type='Visualizer', name='visualizer') + build_from_cfg(cfg, VISUALIZER) + Visualizer.get_current_instance() + + +def test_build_model_from_cfg(): + BACKBONES = Registry('backbone', build_func=build_model_from_cfg) + + @BACKBONES.register_module() + class ResNet(nn.Module): + + def __init__(self, depth, stages=4): + super().__init__() + self.depth = depth + self.stages = stages + + def forward(self, x): + return x + + @BACKBONES.register_module() + class ResNeXt(nn.Module): + + def __init__(self, depth, stages=4): + super().__init__() + self.depth = depth + self.stages = stages + + def forward(self, x): + return x + + cfg = dict(type='ResNet', depth=50) + model = BACKBONES.build(cfg) + assert isinstance(model, ResNet) + assert model.depth == 50 and model.stages == 4 + + cfg = dict(type='ResNeXt', depth=50, stages=3) + model = BACKBONES.build(cfg) + assert isinstance(model, ResNeXt) + assert model.depth == 50 and model.stages == 3 + + cfg = [ + dict(type='ResNet', depth=50), + dict(type='ResNeXt', depth=50, stages=3) + ] + model = BACKBONES.build(cfg) + assert isinstance(model, nn.Sequential) + assert isinstance(model[0], ResNet) + assert model[0].depth == 50 and model[0].stages == 4 + assert isinstance(model[1], ResNeXt) + assert model[1].depth == 50 and model[1].stages == 3 + + # test inherit `build_func` from parent + NEW_MODELS = Registry('models', parent=BACKBONES, scope='new') + assert NEW_MODELS.build_func is build_model_from_cfg + + # test specify `build_func` + def pseudo_build(cfg): + return cfg + + NEW_MODELS = Registry('models', parent=BACKBONES, build_func=pseudo_build) + assert NEW_MODELS.build_func is pseudo_build + + +def test_build_sheduler_from_cfg(): + model = nn.Conv2d(1, 1, 1) + optimizer = SGD(model.parameters(), lr=0.1) + cfg = dict( + type='LinearParamScheduler', + optimizer=optimizer, + param_name='lr', + begin=0, + end=100) + sheduler = PARAM_SCHEDULERS.build(cfg) + assert sheduler.begin == 0 + assert sheduler.end == 100 + + cfg = dict( + type='LinearParamScheduler', + convert_to_iter_based=True, + optimizer=optimizer, + param_name='lr', + begin=0, + end=100, + epoch_length=10) + + sheduler = PARAM_SCHEDULERS.build(cfg) + assert sheduler.begin == 0 + assert sheduler.end == 1000 diff --git a/tests/test_registry/test_registry.py b/tests/test_registry/test_registry.py index 41ad4684..74e51e22 100644 --- a/tests/test_registry/test_registry.py +++ b/tests/test_registry/test_registry.py @@ -2,12 +2,9 @@ import time import pytest -import torch.nn as nn from mmengine.config import Config, ConfigDict # type: ignore -from mmengine.registry import (DefaultScope, Registry, build_from_cfg, - build_model_from_cfg) -from mmengine.utils import ManagerMixin +from mmengine.registry import DefaultScope, Registry, build_from_cfg class TestRegistry: @@ -476,182 +473,3 @@ class TestRegistry: "<locals>.Munchkin'>") repr_str += '})' assert repr(CATS) == repr_str - - -@pytest.mark.parametrize('cfg_type', [dict, ConfigDict, Config]) -def test_build_from_cfg(cfg_type): - BACKBONES = Registry('backbone') - - @BACKBONES.register_module() - class ResNet: - - def __init__(self, depth, stages=4): - self.depth = depth - self.stages = stages - - @BACKBONES.register_module() - class ResNeXt: - - def __init__(self, depth, stages=4): - self.depth = depth - self.stages = stages - - # test `cfg` parameter - # `cfg` should be a dict, ConfigDict or Config object - with pytest.raises( - TypeError, - match=('cfg should be a dict, ConfigDict or Config, but got ' - "<class 'str'>")): - cfg = 'ResNet' - model = build_from_cfg(cfg, BACKBONES) - - # `cfg` is a dict, ConfigDict or Config object - cfg = cfg_type(dict(type='ResNet', depth=50)) - model = build_from_cfg(cfg, BACKBONES) - assert isinstance(model, ResNet) - assert model.depth == 50 and model.stages == 4 - - # `cfg` is a dict but it does not contain the key "type" - with pytest.raises(KeyError, match='must contain the key "type"'): - cfg = dict(depth=50, stages=4) - cfg = cfg_type(cfg) - model = build_from_cfg(cfg, BACKBONES) - - # cfg['type'] should be a str or class - with pytest.raises( - TypeError, - match="type must be a str or valid type, but got <class 'int'>"): - cfg = dict(type=1000) - cfg = cfg_type(cfg) - model = build_from_cfg(cfg, BACKBONES) - - cfg = cfg_type(dict(type='ResNeXt', depth=50, stages=3)) - model = build_from_cfg(cfg, BACKBONES) - assert isinstance(model, ResNeXt) - assert model.depth == 50 and model.stages == 3 - - cfg = cfg_type(dict(type=ResNet, depth=50)) - model = build_from_cfg(cfg, BACKBONES) - assert isinstance(model, ResNet) - assert model.depth == 50 and model.stages == 4 - - # non-registered class - with pytest.raises(KeyError, match='VGG is not in the backbone registry'): - cfg = cfg_type(dict(type='VGG')) - model = build_from_cfg(cfg, BACKBONES) - - # `cfg` contains unexpected arguments - with pytest.raises(TypeError): - cfg = cfg_type(dict(type='ResNet', non_existing_arg=50)) - model = build_from_cfg(cfg, BACKBONES) - - # test `default_args` parameter - cfg = cfg_type(dict(type='ResNet', depth=50)) - model = build_from_cfg(cfg, BACKBONES, cfg_type(dict(stages=3))) - assert isinstance(model, ResNet) - assert model.depth == 50 and model.stages == 3 - - # default_args must be a dict or None - with pytest.raises(TypeError): - cfg = cfg_type(dict(type='ResNet', depth=50)) - model = build_from_cfg(cfg, BACKBONES, default_args=1) - - # cfg or default_args should contain the key "type" - with pytest.raises(KeyError, match='must contain the key "type"'): - cfg = cfg_type(dict(depth=50)) - model = build_from_cfg( - cfg, BACKBONES, default_args=cfg_type(dict(stages=4))) - - # "type" defined using default_args - cfg = cfg_type(dict(depth=50)) - model = build_from_cfg( - cfg, BACKBONES, default_args=cfg_type(dict(type='ResNet'))) - assert isinstance(model, ResNet) - assert model.depth == 50 and model.stages == 4 - - cfg = cfg_type(dict(depth=50)) - model = build_from_cfg( - cfg, BACKBONES, default_args=cfg_type(dict(type=ResNet))) - assert isinstance(model, ResNet) - assert model.depth == 50 and model.stages == 4 - - # test `registry` parameter - # incorrect registry type - with pytest.raises( - TypeError, - match=('registry must be a mmengine.Registry object, but got ' - "<class 'str'>")): - cfg = cfg_type(dict(type='ResNet', depth=50)) - model = build_from_cfg(cfg, 'BACKBONES') - - VISUALIZER = Registry('visualizer') - - @VISUALIZER.register_module() - class Visualizer(ManagerMixin): - - def __init__(self, name): - super().__init__(name) - - with pytest.raises(RuntimeError): - Visualizer.get_current_instance() - cfg = dict(type='Visualizer', name='visualizer') - build_from_cfg(cfg, VISUALIZER) - Visualizer.get_current_instance() - - -def test_build_model_from_cfg(): - BACKBONES = Registry('backbone', build_func=build_model_from_cfg) - - @BACKBONES.register_module() - class ResNet(nn.Module): - - def __init__(self, depth, stages=4): - super().__init__() - self.depth = depth - self.stages = stages - - def forward(self, x): - return x - - @BACKBONES.register_module() - class ResNeXt(nn.Module): - - def __init__(self, depth, stages=4): - super().__init__() - self.depth = depth - self.stages = stages - - def forward(self, x): - return x - - cfg = dict(type='ResNet', depth=50) - model = BACKBONES.build(cfg) - assert isinstance(model, ResNet) - assert model.depth == 50 and model.stages == 4 - - cfg = dict(type='ResNeXt', depth=50, stages=3) - model = BACKBONES.build(cfg) - assert isinstance(model, ResNeXt) - assert model.depth == 50 and model.stages == 3 - - cfg = [ - dict(type='ResNet', depth=50), - dict(type='ResNeXt', depth=50, stages=3) - ] - model = BACKBONES.build(cfg) - assert isinstance(model, nn.Sequential) - assert isinstance(model[0], ResNet) - assert model[0].depth == 50 and model[0].stages == 4 - assert isinstance(model[1], ResNeXt) - assert model[1].depth == 50 and model[1].stages == 3 - - # test inherit `build_func` from parent - NEW_MODELS = Registry('models', parent=BACKBONES, scope='new') - assert NEW_MODELS.build_func is build_model_from_cfg - - # test specify `build_func` - def pseudo_build(cfg): - return cfg - - NEW_MODELS = Registry('models', parent=BACKBONES, build_func=pseudo_build) - assert NEW_MODELS.build_func is pseudo_build diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index f382bc3b..ebceb5ce 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -993,11 +993,6 @@ class TestRunner(TestCase): # 5.1 train loop should be built before converting scheduler cfg = dict( type='MultiStepLR', milestones=[1, 2], convert_to_iter_based=True) - with self.assertRaisesRegex( - AssertionError, - 'Scheduler can only be converted to iter-based when ' - 'train loop is built.'): - runner.build_param_scheduler(cfg) # 5.2 convert epoch-based to iter-based scheduler cfg = dict( -- GitLab