From a07a063306f24fb0f4d59af19da4b33326cf3286 Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Mon, 8 Aug 2022 20:34:16 +0800
Subject: [PATCH] [Enhance] Add build function for scheduler. (#372)

* add build function for scheduler

* add unit test

add unit test

* handle convert_to_iter in build_scheduler_from_cfg

* restore deleted code

* format import

* fix lint
---
 mmengine/registry/__init__.py               |   5 +-
 mmengine/registry/build_functions.py        |  82 +++++++-
 mmengine/registry/root.py                   |   6 +-
 mmengine/runner/runner.py                   |  28 +--
 tests/test_registry/test_build_functions.py | 213 ++++++++++++++++++++
 tests/test_registry/test_registry.py        | 184 +----------------
 tests/test_runner/test_runner.py            |   5 -
 7 files changed, 305 insertions(+), 218 deletions(-)
 create mode 100644 tests/test_registry/test_build_functions.py

diff --git a/mmengine/registry/__init__.py b/mmengine/registry/__init__.py
index 8be89e69..1e1de6bf 100644
--- a/mmengine/registry/__init__.py
+++ b/mmengine/registry/__init__.py
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .build_functions import (build_from_cfg, build_model_from_cfg,
-                              build_runner_from_cfg)
+                              build_runner_from_cfg, build_scheduler_from_cfg)
 from .default_scope import DefaultScope
 from .registry import Registry
 from .root import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS, LOG_PROCESSORS,
@@ -17,5 +17,6 @@ __all__ = [
     'PARAM_SCHEDULERS', 'METRICS', 'MODEL_WRAPPERS', 'OPTIM_WRAPPERS', 'LOOPS',
     'VISBACKENDS', 'VISUALIZERS', 'LOG_PROCESSORS', 'EVALUATOR',
     'DefaultScope', 'traverse_registry_tree', 'count_registered_modules',
-    'build_model_from_cfg', 'build_runner_from_cfg', 'build_from_cfg'
+    'build_model_from_cfg', 'build_runner_from_cfg', 'build_from_cfg',
+    'build_scheduler_from_cfg'
 ]
diff --git a/mmengine/registry/build_functions.py b/mmengine/registry/build_functions.py
index 0f4dbb96..78348086 100644
--- a/mmengine/registry/build_functions.py
+++ b/mmengine/registry/build_functions.py
@@ -1,12 +1,18 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import inspect
 import logging
-from typing import Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+import torch.nn as nn
 
 from ..config import Config, ConfigDict
 from ..utils import ManagerMixin
 from .registry import Registry
 
+if TYPE_CHECKING:
+    from ..optim.scheduler import _ParamScheduler
+    from ..runner import Runner
+
 
 def build_from_cfg(
         cfg: Union[dict, ConfigDict, Config],
@@ -131,7 +137,7 @@ def build_from_cfg(
 
 
 def build_runner_from_cfg(cfg: Union[dict, ConfigDict, Config],
-                          registry: Registry) -> Any:
+                          registry: Registry) -> 'Runner':
     """Build a Runner object.
     Examples:
         >>> from mmengine.registry import Registry, build_runner_from_cfg
@@ -203,7 +209,11 @@ def build_runner_from_cfg(cfg: Union[dict, ConfigDict, Config],
                 f'{cls_location}.py: {e}')
 
 
-def build_model_from_cfg(cfg, registry, default_args=None):
+def build_model_from_cfg(
+        cfg: Union[dict, ConfigDict, Config],
+        registry: Registry,
+        default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> \
+        nn.Module:
     """Build a PyTorch model from config dict(s). Different from
     ``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built.
 
@@ -226,3 +236,69 @@ def build_model_from_cfg(cfg, registry, default_args=None):
         return Sequential(*modules)
     else:
         return build_from_cfg(cfg, registry, default_args)
+
+
+def build_scheduler_from_cfg(
+        cfg: Union[dict, ConfigDict, Config],
+        registry: Registry,
+        default_args: Optional[Union[dict, ConfigDict, Config]] = None) -> \
+        '_ParamScheduler':
+    """Builds a ``ParamScheduler`` instance from config.
+
+    ``ParamScheduler`` supports building instance by its constructor or
+    method ``build_iter_from_epoch``. Therefore, its registry needs a build
+    function to handle both cases.
+
+    Args:
+        cfg (dict or ConfigDict or Config): Config dictionary. If it contains
+            the key ``convert_to_iter_based``, instance will be built by method
+            ``convert_to_iter_based``, otherwise instance will be built by its
+            constructor.
+        registry (:obj:`Registry`): The ``PARAM_SCHEDULERS`` registry.
+        default_args (dict or ConfigDict or Config, optional): Default
+            initialization arguments. It must contain key ``optimizer``. If
+            ``convert_to_iter_based`` is defined in ``cfg``, it must
+            additionally contain key ``epoch_length``. Defaults to None.
+
+    Returns:
+        object: The constructed ``ParamScheduler``.
+    """
+    assert isinstance(
+        cfg,
+        (dict, ConfigDict, Config
+         )), f'cfg should be a dict, ConfigDict or Config, but got {type(cfg)}'
+    assert isinstance(
+        registry, Registry), ('registry should be a mmengine.Registry object',
+                              f'but got {type(registry)}')
+
+    args = cfg.copy()
+    if default_args is not None:
+        for name, value in default_args.items():
+            args.setdefault(name, value)
+    scope = args.pop('_scope_', None)
+    with registry.switch_scope_and_registry(scope) as registry:
+        convert_to_iter = args.pop('convert_to_iter_based', False)
+        if convert_to_iter:
+            scheduler_type = args.pop('type')
+            assert 'epoch_length' in args and args.get('by_epoch', True), (
+                'Only epoch-based parameter scheduler can be converted to '
+                'iter-based, and `epoch_length` should be set')
+            if isinstance(scheduler_type, str):
+                scheduler_cls = registry.get(scheduler_type)
+                if scheduler_cls is None:
+                    raise KeyError(
+                        f'{scheduler_type} is not in the {registry.name} '
+                        'registry. Please check whether the value of '
+                        f'`{scheduler_type}` is correct or it was registered '
+                        'as expected. More details can be found at https://mmengine.readthedocs.io/en/latest/tutorials/config.html#import-custom-python-modules'  # noqa: E501
+                    )
+            elif inspect.isclass(scheduler_type):
+                scheduler_cls = scheduler_type
+            else:
+                raise TypeError('type must be a str or valid type, but got '
+                                f'{type(scheduler_type)}')
+            return scheduler_cls.build_iter_from_epoch(  # type: ignore
+                **args)
+        else:
+            args.pop('epoch_length', None)
+            return build_from_cfg(args, registry)
diff --git a/mmengine/registry/root.py b/mmengine/registry/root.py
index a63686f9..838465b4 100644
--- a/mmengine/registry/root.py
+++ b/mmengine/registry/root.py
@@ -6,7 +6,8 @@ More datails can be found at
 https://mmengine.readthedocs.io/en/latest/tutorials/registry.html.
 """
 
-from mmengine.registry import build_model_from_cfg, build_runner_from_cfg
+from .build_functions import (build_model_from_cfg, build_runner_from_cfg,
+                              build_scheduler_from_cfg)
 from .registry import Registry
 
 # manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
@@ -37,7 +38,8 @@ OPTIM_WRAPPERS = Registry('optim_wrapper')
 # manage constructors that customize the optimization hyperparameters.
 OPTIM_WRAPPER_CONSTRUCTORS = Registry('optimizer wrapper constructor')
 # mangage all kinds of parameter schedulers like `MultiStepLR`
-PARAM_SCHEDULERS = Registry('parameter scheduler')
+PARAM_SCHEDULERS = Registry(
+    'parameter scheduler', build_func=build_scheduler_from_cfg)
 
 # manage all kinds of metrics
 METRICS = Registry('metric')
diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index 3e416842..b7b2f8ef 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -1134,34 +1134,16 @@ class Runner:
                         f'The `end` of {_scheduler["type"]} is not set. '
                         'Use the max epochs/iters of train loop as default.')
 
-                convert_to_iter = _scheduler.pop('convert_to_iter_based',
-                                                 False)
-                if convert_to_iter:
-                    assert _scheduler.get(
-                        'by_epoch',
-                        True), ('only epoch-based parameter scheduler can be '
-                                'converted to iter-based')
-                    assert isinstance(self._train_loop, BaseLoop), \
-                        'Scheduler can only be converted to iter-based ' \
-                        'when train loop is built.'
-                    cls = PARAM_SCHEDULERS.get(_scheduler.pop('type'))
-                    param_schedulers.append(
-                        cls.build_iter_from_epoch(  # type: ignore
+                param_schedulers.append(
+                    PARAM_SCHEDULERS.build(
+                        _scheduler,
+                        default_args=dict(
                             optimizer=optim_wrapper,
-                            **_scheduler,
-                            epoch_length=len(
-                                self.train_dataloader),  # type: ignore
-                        ))
-                else:
-                    param_schedulers.append(
-                        PARAM_SCHEDULERS.build(
-                            _scheduler,
-                            default_args=dict(optimizer=optim_wrapper)))
+                            epoch_length=len(self.train_dataloader))))
             else:
                 raise TypeError(
                     'scheduler should be a _ParamScheduler object or dict, '
                     f'but got {scheduler}')
-
         return param_schedulers
 
     def build_param_scheduler(
diff --git a/tests/test_registry/test_build_functions.py b/tests/test_registry/test_build_functions.py
new file mode 100644
index 00000000..618f99e3
--- /dev/null
+++ b/tests/test_registry/test_build_functions.py
@@ -0,0 +1,213 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import pytest
+import torch.nn as nn
+from torch.optim import SGD
+
+from mmengine import (PARAM_SCHEDULERS, Config, ConfigDict, ManagerMixin,
+                      Registry, build_from_cfg, build_model_from_cfg)
+
+
+@pytest.mark.parametrize('cfg_type', [dict, ConfigDict, Config])
+def test_build_from_cfg(cfg_type):
+    BACKBONES = Registry('backbone')
+
+    @BACKBONES.register_module()
+    class ResNet:
+
+        def __init__(self, depth, stages=4):
+            self.depth = depth
+            self.stages = stages
+
+    @BACKBONES.register_module()
+    class ResNeXt:
+
+        def __init__(self, depth, stages=4):
+            self.depth = depth
+            self.stages = stages
+
+    # test `cfg` parameter
+    # `cfg` should be a dict, ConfigDict or Config object
+    with pytest.raises(
+            TypeError,
+            match=('cfg should be a dict, ConfigDict or Config, but got '
+                   "<class 'str'>")):
+        cfg = 'ResNet'
+        model = build_from_cfg(cfg, BACKBONES)
+
+    # `cfg` is a dict, ConfigDict or Config object
+    cfg = cfg_type(dict(type='ResNet', depth=50))
+    model = build_from_cfg(cfg, BACKBONES)
+    assert isinstance(model, ResNet)
+    assert model.depth == 50 and model.stages == 4
+
+    # `cfg` is a dict but it does not contain the key "type"
+    with pytest.raises(KeyError, match='must contain the key "type"'):
+        cfg = dict(depth=50, stages=4)
+        cfg = cfg_type(cfg)
+        model = build_from_cfg(cfg, BACKBONES)
+
+    # cfg['type'] should be a str or class
+    with pytest.raises(
+            TypeError,
+            match="type must be a str or valid type, but got <class 'int'>"):
+        cfg = dict(type=1000)
+        cfg = cfg_type(cfg)
+        model = build_from_cfg(cfg, BACKBONES)
+
+    cfg = cfg_type(dict(type='ResNeXt', depth=50, stages=3))
+    model = build_from_cfg(cfg, BACKBONES)
+    assert isinstance(model, ResNeXt)
+    assert model.depth == 50 and model.stages == 3
+
+    cfg = cfg_type(dict(type=ResNet, depth=50))
+    model = build_from_cfg(cfg, BACKBONES)
+    assert isinstance(model, ResNet)
+    assert model.depth == 50 and model.stages == 4
+
+    # non-registered class
+    with pytest.raises(KeyError, match='VGG is not in the backbone registry'):
+        cfg = cfg_type(dict(type='VGG'))
+        model = build_from_cfg(cfg, BACKBONES)
+
+    # `cfg` contains unexpected arguments
+    with pytest.raises(TypeError):
+        cfg = cfg_type(dict(type='ResNet', non_existing_arg=50))
+        model = build_from_cfg(cfg, BACKBONES)
+
+    # test `default_args` parameter
+    cfg = cfg_type(dict(type='ResNet', depth=50))
+    model = build_from_cfg(cfg, BACKBONES, cfg_type(dict(stages=3)))
+    assert isinstance(model, ResNet)
+    assert model.depth == 50 and model.stages == 3
+
+    # default_args must be a dict or None
+    with pytest.raises(TypeError):
+        cfg = cfg_type(dict(type='ResNet', depth=50))
+        model = build_from_cfg(cfg, BACKBONES, default_args=1)
+
+    # cfg or default_args should contain the key "type"
+    with pytest.raises(KeyError, match='must contain the key "type"'):
+        cfg = cfg_type(dict(depth=50))
+        model = build_from_cfg(
+            cfg, BACKBONES, default_args=cfg_type(dict(stages=4)))
+
+    # "type" defined using default_args
+    cfg = cfg_type(dict(depth=50))
+    model = build_from_cfg(
+        cfg, BACKBONES, default_args=cfg_type(dict(type='ResNet')))
+    assert isinstance(model, ResNet)
+    assert model.depth == 50 and model.stages == 4
+
+    cfg = cfg_type(dict(depth=50))
+    model = build_from_cfg(
+        cfg, BACKBONES, default_args=cfg_type(dict(type=ResNet)))
+    assert isinstance(model, ResNet)
+    assert model.depth == 50 and model.stages == 4
+
+    # test `registry` parameter
+    # incorrect registry type
+    with pytest.raises(
+            TypeError,
+            match=('registry must be a mmengine.Registry object, but got '
+                   "<class 'str'>")):
+        cfg = cfg_type(dict(type='ResNet', depth=50))
+        model = build_from_cfg(cfg, 'BACKBONES')
+
+    VISUALIZER = Registry('visualizer')
+
+    @VISUALIZER.register_module()
+    class Visualizer(ManagerMixin):
+
+        def __init__(self, name):
+            super().__init__(name)
+
+    with pytest.raises(RuntimeError):
+        Visualizer.get_current_instance()
+    cfg = dict(type='Visualizer', name='visualizer')
+    build_from_cfg(cfg, VISUALIZER)
+    Visualizer.get_current_instance()
+
+
+def test_build_model_from_cfg():
+    BACKBONES = Registry('backbone', build_func=build_model_from_cfg)
+
+    @BACKBONES.register_module()
+    class ResNet(nn.Module):
+
+        def __init__(self, depth, stages=4):
+            super().__init__()
+            self.depth = depth
+            self.stages = stages
+
+        def forward(self, x):
+            return x
+
+    @BACKBONES.register_module()
+    class ResNeXt(nn.Module):
+
+        def __init__(self, depth, stages=4):
+            super().__init__()
+            self.depth = depth
+            self.stages = stages
+
+        def forward(self, x):
+            return x
+
+    cfg = dict(type='ResNet', depth=50)
+    model = BACKBONES.build(cfg)
+    assert isinstance(model, ResNet)
+    assert model.depth == 50 and model.stages == 4
+
+    cfg = dict(type='ResNeXt', depth=50, stages=3)
+    model = BACKBONES.build(cfg)
+    assert isinstance(model, ResNeXt)
+    assert model.depth == 50 and model.stages == 3
+
+    cfg = [
+        dict(type='ResNet', depth=50),
+        dict(type='ResNeXt', depth=50, stages=3)
+    ]
+    model = BACKBONES.build(cfg)
+    assert isinstance(model, nn.Sequential)
+    assert isinstance(model[0], ResNet)
+    assert model[0].depth == 50 and model[0].stages == 4
+    assert isinstance(model[1], ResNeXt)
+    assert model[1].depth == 50 and model[1].stages == 3
+
+    # test inherit `build_func` from parent
+    NEW_MODELS = Registry('models', parent=BACKBONES, scope='new')
+    assert NEW_MODELS.build_func is build_model_from_cfg
+
+    # test specify `build_func`
+    def pseudo_build(cfg):
+        return cfg
+
+    NEW_MODELS = Registry('models', parent=BACKBONES, build_func=pseudo_build)
+    assert NEW_MODELS.build_func is pseudo_build
+
+
+def test_build_sheduler_from_cfg():
+    model = nn.Conv2d(1, 1, 1)
+    optimizer = SGD(model.parameters(), lr=0.1)
+    cfg = dict(
+        type='LinearParamScheduler',
+        optimizer=optimizer,
+        param_name='lr',
+        begin=0,
+        end=100)
+    sheduler = PARAM_SCHEDULERS.build(cfg)
+    assert sheduler.begin == 0
+    assert sheduler.end == 100
+
+    cfg = dict(
+        type='LinearParamScheduler',
+        convert_to_iter_based=True,
+        optimizer=optimizer,
+        param_name='lr',
+        begin=0,
+        end=100,
+        epoch_length=10)
+
+    sheduler = PARAM_SCHEDULERS.build(cfg)
+    assert sheduler.begin == 0
+    assert sheduler.end == 1000
diff --git a/tests/test_registry/test_registry.py b/tests/test_registry/test_registry.py
index 41ad4684..74e51e22 100644
--- a/tests/test_registry/test_registry.py
+++ b/tests/test_registry/test_registry.py
@@ -2,12 +2,9 @@
 import time
 
 import pytest
-import torch.nn as nn
 
 from mmengine.config import Config, ConfigDict  # type: ignore
-from mmengine.registry import (DefaultScope, Registry, build_from_cfg,
-                               build_model_from_cfg)
-from mmengine.utils import ManagerMixin
+from mmengine.registry import DefaultScope, Registry, build_from_cfg
 
 
 class TestRegistry:
@@ -476,182 +473,3 @@ class TestRegistry:
             "<locals>.Munchkin'>")
         repr_str += '})'
         assert repr(CATS) == repr_str
-
-
-@pytest.mark.parametrize('cfg_type', [dict, ConfigDict, Config])
-def test_build_from_cfg(cfg_type):
-    BACKBONES = Registry('backbone')
-
-    @BACKBONES.register_module()
-    class ResNet:
-
-        def __init__(self, depth, stages=4):
-            self.depth = depth
-            self.stages = stages
-
-    @BACKBONES.register_module()
-    class ResNeXt:
-
-        def __init__(self, depth, stages=4):
-            self.depth = depth
-            self.stages = stages
-
-    # test `cfg` parameter
-    # `cfg` should be a dict, ConfigDict or Config object
-    with pytest.raises(
-            TypeError,
-            match=('cfg should be a dict, ConfigDict or Config, but got '
-                   "<class 'str'>")):
-        cfg = 'ResNet'
-        model = build_from_cfg(cfg, BACKBONES)
-
-    # `cfg` is a dict, ConfigDict or Config object
-    cfg = cfg_type(dict(type='ResNet', depth=50))
-    model = build_from_cfg(cfg, BACKBONES)
-    assert isinstance(model, ResNet)
-    assert model.depth == 50 and model.stages == 4
-
-    # `cfg` is a dict but it does not contain the key "type"
-    with pytest.raises(KeyError, match='must contain the key "type"'):
-        cfg = dict(depth=50, stages=4)
-        cfg = cfg_type(cfg)
-        model = build_from_cfg(cfg, BACKBONES)
-
-    # cfg['type'] should be a str or class
-    with pytest.raises(
-            TypeError,
-            match="type must be a str or valid type, but got <class 'int'>"):
-        cfg = dict(type=1000)
-        cfg = cfg_type(cfg)
-        model = build_from_cfg(cfg, BACKBONES)
-
-    cfg = cfg_type(dict(type='ResNeXt', depth=50, stages=3))
-    model = build_from_cfg(cfg, BACKBONES)
-    assert isinstance(model, ResNeXt)
-    assert model.depth == 50 and model.stages == 3
-
-    cfg = cfg_type(dict(type=ResNet, depth=50))
-    model = build_from_cfg(cfg, BACKBONES)
-    assert isinstance(model, ResNet)
-    assert model.depth == 50 and model.stages == 4
-
-    # non-registered class
-    with pytest.raises(KeyError, match='VGG is not in the backbone registry'):
-        cfg = cfg_type(dict(type='VGG'))
-        model = build_from_cfg(cfg, BACKBONES)
-
-    # `cfg` contains unexpected arguments
-    with pytest.raises(TypeError):
-        cfg = cfg_type(dict(type='ResNet', non_existing_arg=50))
-        model = build_from_cfg(cfg, BACKBONES)
-
-    # test `default_args` parameter
-    cfg = cfg_type(dict(type='ResNet', depth=50))
-    model = build_from_cfg(cfg, BACKBONES, cfg_type(dict(stages=3)))
-    assert isinstance(model, ResNet)
-    assert model.depth == 50 and model.stages == 3
-
-    # default_args must be a dict or None
-    with pytest.raises(TypeError):
-        cfg = cfg_type(dict(type='ResNet', depth=50))
-        model = build_from_cfg(cfg, BACKBONES, default_args=1)
-
-    # cfg or default_args should contain the key "type"
-    with pytest.raises(KeyError, match='must contain the key "type"'):
-        cfg = cfg_type(dict(depth=50))
-        model = build_from_cfg(
-            cfg, BACKBONES, default_args=cfg_type(dict(stages=4)))
-
-    # "type" defined using default_args
-    cfg = cfg_type(dict(depth=50))
-    model = build_from_cfg(
-        cfg, BACKBONES, default_args=cfg_type(dict(type='ResNet')))
-    assert isinstance(model, ResNet)
-    assert model.depth == 50 and model.stages == 4
-
-    cfg = cfg_type(dict(depth=50))
-    model = build_from_cfg(
-        cfg, BACKBONES, default_args=cfg_type(dict(type=ResNet)))
-    assert isinstance(model, ResNet)
-    assert model.depth == 50 and model.stages == 4
-
-    # test `registry` parameter
-    # incorrect registry type
-    with pytest.raises(
-            TypeError,
-            match=('registry must be a mmengine.Registry object, but got '
-                   "<class 'str'>")):
-        cfg = cfg_type(dict(type='ResNet', depth=50))
-        model = build_from_cfg(cfg, 'BACKBONES')
-
-    VISUALIZER = Registry('visualizer')
-
-    @VISUALIZER.register_module()
-    class Visualizer(ManagerMixin):
-
-        def __init__(self, name):
-            super().__init__(name)
-
-    with pytest.raises(RuntimeError):
-        Visualizer.get_current_instance()
-    cfg = dict(type='Visualizer', name='visualizer')
-    build_from_cfg(cfg, VISUALIZER)
-    Visualizer.get_current_instance()
-
-
-def test_build_model_from_cfg():
-    BACKBONES = Registry('backbone', build_func=build_model_from_cfg)
-
-    @BACKBONES.register_module()
-    class ResNet(nn.Module):
-
-        def __init__(self, depth, stages=4):
-            super().__init__()
-            self.depth = depth
-            self.stages = stages
-
-        def forward(self, x):
-            return x
-
-    @BACKBONES.register_module()
-    class ResNeXt(nn.Module):
-
-        def __init__(self, depth, stages=4):
-            super().__init__()
-            self.depth = depth
-            self.stages = stages
-
-        def forward(self, x):
-            return x
-
-    cfg = dict(type='ResNet', depth=50)
-    model = BACKBONES.build(cfg)
-    assert isinstance(model, ResNet)
-    assert model.depth == 50 and model.stages == 4
-
-    cfg = dict(type='ResNeXt', depth=50, stages=3)
-    model = BACKBONES.build(cfg)
-    assert isinstance(model, ResNeXt)
-    assert model.depth == 50 and model.stages == 3
-
-    cfg = [
-        dict(type='ResNet', depth=50),
-        dict(type='ResNeXt', depth=50, stages=3)
-    ]
-    model = BACKBONES.build(cfg)
-    assert isinstance(model, nn.Sequential)
-    assert isinstance(model[0], ResNet)
-    assert model[0].depth == 50 and model[0].stages == 4
-    assert isinstance(model[1], ResNeXt)
-    assert model[1].depth == 50 and model[1].stages == 3
-
-    # test inherit `build_func` from parent
-    NEW_MODELS = Registry('models', parent=BACKBONES, scope='new')
-    assert NEW_MODELS.build_func is build_model_from_cfg
-
-    # test specify `build_func`
-    def pseudo_build(cfg):
-        return cfg
-
-    NEW_MODELS = Registry('models', parent=BACKBONES, build_func=pseudo_build)
-    assert NEW_MODELS.build_func is pseudo_build
diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py
index f382bc3b..ebceb5ce 100644
--- a/tests/test_runner/test_runner.py
+++ b/tests/test_runner/test_runner.py
@@ -993,11 +993,6 @@ class TestRunner(TestCase):
         # 5.1 train loop should be built before converting scheduler
         cfg = dict(
             type='MultiStepLR', milestones=[1, 2], convert_to_iter_based=True)
-        with self.assertRaisesRegex(
-                AssertionError,
-                'Scheduler can only be converted to iter-based when '
-                'train loop is built.'):
-            runner.build_param_scheduler(cfg)
 
         # 5.2 convert epoch-based to iter-based scheduler
         cfg = dict(
-- 
GitLab