Skip to content
Snippets Groups Projects
Unverified Commit 25014af3 authored by RangiLyu's avatar RangiLyu Committed by GitHub
Browse files

[Refactor] Refactor default_scope in Registry. (#158)

parent e80267ae
No related branches found
No related tags found
No related merge requests found
......@@ -311,11 +311,15 @@ from mmcls.models import MODELS
model = MODELS.build(cfg=dict(type='mmdet.RetinaNet'))
```
调用兄弟节点的模块需要指定在 `type` 中指定 `scope` 前缀,如果不想指定,我们可以`build` 方法中的 `default_scope` 参数设置为 'mmdet',它会将 `default_scope` 对应的 `registry` 作为当前 `Registry` 并调用 `build` 方法。
调用非本节点的模块需要指定在 `type` 中指定 `scope` 前缀,如果不想指定,我们可以创建一个全局变量 `default_scope` 并将 `scope_name` 设置为 'mmdet',`Registry` 会将 `scope_name` 对应的 `registry` 作为当前 `Registry` 并调用 `build` 方法。
```python
from mmcls.models import MODELS
model = MODELS.build(cfg=dict(type='RetinaNet'), default_scope='mmdet')
from mmengine.registry import DefaultScope, MODELS
# 调用注册在 mmdet 中的 RetinaNet
default_scope = DefaultScope.get_instance(
'my_experiment', scope_name='mmdet')
model = MODELS.build(cfg=dict(type='RetinaNet'))
```
注册器除了支持两层结构,三层甚至更多层结构也是支持的。
......@@ -325,7 +329,7 @@ model = MODELS.build(cfg=dict(type='RetinaNet'), default_scope='mmdet')
`DetPlus` 中定义了模块 `MetaNet`
```python
from mmengine.model import Registry
from mmengine.registry import Registry
from mmdet.model import MODELS as MMDET_MODELS
MODELS = Registry('model', parent=MMDET_MODELS, scope='det_plus')
......@@ -354,6 +358,10 @@ model = MODELS.build(cfg=dict(type='mmcls.ResNet'))
from mmcls.models import MODELS
# 需要注意前缀的顺序,'detplus.mmdet.ResNet' 是不正确的
model = MODELS.build(cfg=dict(type='mmdet.detplus.MetaNet'))
# 当然,更简单的方法是直接设置 default_scope
# 如果希望默认从 detplus 构建模型,设置可以 default_scope
from mmengine.registry import DefaultScope
default_scope = DefaultScope.get_instance(
'my_experiment', scope_name='detplus')
model = MODELS.build(cfg=dict(type='MetaNet', default_scope='detplus'))
```
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union
from typing import Union
from ..registry import EVALUATORS
from .base import BaseEvaluator
......@@ -7,9 +7,7 @@ from .composed_evaluator import ComposedEvaluator
def build_evaluator(
cfg: Union[dict, list],
default_scope: Optional[str] = None
) -> Union[BaseEvaluator, ComposedEvaluator]:
cfg: Union[dict, list]) -> Union[BaseEvaluator, ComposedEvaluator]:
"""Build function of evaluator.
When the evaluator config is a list, it will automatically build composed
......@@ -18,16 +16,12 @@ def build_evaluator(
Args:
cfg (dict | list): Config of evaluator. When the config is a list, it
will automatically build composed evaluators.
default_scope (str, optional): The ``default_scope`` is used to
reset the current registry. Defaults to None.
Returns:
BaseEvaluator or ComposedEvaluator: The built evaluator.
"""
if isinstance(cfg, list):
evaluators = [
EVALUATORS.build(_cfg, default_scope=default_scope) for _cfg in cfg
]
evaluators = [EVALUATORS.build(_cfg) for _cfg in cfg]
return ComposedEvaluator(evaluators=evaluators)
else:
return EVALUATORS.build(cfg, default_scope=default_scope)
return EVALUATORS.build(cfg)
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import inspect
from typing import List, Optional
from typing import List
import torch
import torch.nn as nn
......@@ -30,10 +30,7 @@ def register_torch_optimizers() -> List[str]:
TORCH_OPTIMIZERS = register_torch_optimizers()
def build_optimizer(
model: nn.Module,
cfg: dict,
default_scope: Optional[str] = None) -> torch.optim.Optimizer:
def build_optimizer(model: nn.Module, cfg: dict) -> torch.optim.Optimizer:
"""Build function of optimizer.
If ``constructor`` is set in the ``cfg``, this method will build an
......@@ -58,7 +55,6 @@ def build_optimizer(
dict(
type=constructor_type,
optimizer_cfg=optimizer_cfg,
paramwise_cfg=paramwise_cfg),
default_scope=default_scope)
optimizer = optim_constructor(model, default_scope=default_scope)
paramwise_cfg=paramwise_cfg))
optimizer = optim_constructor(model)
return optimizer
......@@ -241,9 +241,7 @@ class DefaultOptimizerConstructor:
prefix=child_prefix,
is_dcn_module=is_dcn_module)
def __call__(self,
model: nn.Module,
default_scope: Optional[str] = None) -> torch.optim.Optimizer:
def __call__(self, model: nn.Module) -> torch.optim.Optimizer:
if hasattr(model, 'module'):
model = model.module
......@@ -251,11 +249,11 @@ class DefaultOptimizerConstructor:
# if no paramwise option is specified, just use the global setting
if not self.paramwise_cfg:
optimizer_cfg['params'] = model.parameters()
return OPTIMIZERS.build(optimizer_cfg, default_scope=default_scope)
return OPTIMIZERS.build(optimizer_cfg)
# set param-wise lr and weight decay recursively
params: List = []
self.add_params(params, model)
optimizer_cfg['params'] = params
return OPTIMIZERS.build(optimizer_cfg, default_scope=default_scope)
return OPTIMIZERS.build(optimizer_cfg)
......@@ -25,8 +25,6 @@ class DefaultScope(ManagerMixin):
>>> DefaultScope.get_instance('task', scope_name='mmdet')
>>> # Get default scope globally.
>>> scope_name = DefaultScope.get_instance('task').scope_name
>>> # build model from cfg.
>>> model = MODELS.build(model_cfg, default_scope=scope_name)
"""
def __init__(self, name: str, scope_name: str):
......
......@@ -7,6 +7,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
from ..config import Config, ConfigDict
from ..utils import is_seq_of
from .default_scope import DefaultScope
def build_from_cfg(
......@@ -354,19 +355,13 @@ class Registry:
return None
def build(self,
*args,
default_scope: Optional[str] = None,
**kwargs) -> Any:
def build(self, *args, **kwargs) -> Any:
"""Build an instance.
Build an instance by calling :attr:`build_func`. If
:attr:`default_scope` is given, :meth:`build` will firstly get the
responding registry and then call its own :meth:`build`.
Args:
default_scope (str, optional): The ``default_scope`` is used to
reset the current registry. Defaults to None.
Build an instance by calling :attr:`build_func`. If the global
variable default scope (:obj:`DefaultScope`) exists ,
:meth:`build` will firstly get the responding registry and then call
its own :meth:`build`.
Examples:
>>> from mmengine import Registry
......@@ -379,9 +374,11 @@ class Registry:
>>> cfg = dict(type='ResNet', depth=50)
>>> model = MODELS.build(cfg)
"""
# get the global default scope
default_scope = DefaultScope.get_current_instance()
if default_scope is not None:
root = self._get_root_registry()
registry = root._search_child(default_scope)
registry = root._search_child(default_scope.scope_name)
if registry is None:
# if `default_scope` can not be found, fallback to use self
warnings.warn(
......
......@@ -675,8 +675,7 @@ class Runner:
if isinstance(model, nn.Module):
return model
elif isinstance(model, dict):
return MODELS.build(
model, default_scope=self.default_scope.scope_name)
return MODELS.build(model)
else:
raise TypeError('model should be a nn.Module object or dict, '
f'but got {model}')
......@@ -726,9 +725,7 @@ class Runner:
model = model.cuda()
else:
model = MODEL_WRAPPERS.build(
model_wrapper_cfg,
default_scope=self.default_scope.scope_name,
default_args=dict(model=self.model))
model_wrapper_cfg, default_args=dict(model=self.model))
return model
......@@ -750,10 +747,7 @@ class Runner:
if isinstance(optimizer, Optimizer):
return optimizer
elif isinstance(optimizer, dict):
optimizer = build_optimizer(
self.model,
optimizer,
default_scope=self.default_scope.scope_name)
optimizer = build_optimizer(self.model, optimizer)
return optimizer
else:
raise TypeError('optimizer should be an Optimizer object or dict, '
......@@ -801,7 +795,6 @@ class Runner:
param_schedulers.append(
PARAM_SCHEDULERS.build(
_scheduler,
default_scope=self.default_scope.scope_name,
default_args=dict(optimizer=self.optimizer)))
else:
raise TypeError(
......@@ -837,9 +830,7 @@ class Runner:
if isinstance(evaluator, (BaseEvaluator, ComposedEvaluator)):
return evaluator
elif isinstance(evaluator, dict) or is_list_of(evaluator, dict):
return build_evaluator(
evaluator,
default_scope=self.default_scope.scope_name) # type: ignore
return build_evaluator(evaluator) # type: ignore
else:
raise TypeError(
'evaluator should be one of dict, list of dict, BaseEvaluator '
......@@ -880,8 +871,7 @@ class Runner:
# build dataset
dataset_cfg = dataloader_cfg.pop('dataset')
if isinstance(dataset_cfg, dict):
dataset = DATASETS.build(
dataset_cfg, default_scope=self.default_scope.scope_name)
dataset = DATASETS.build(dataset_cfg)
else:
# fallback to raise error in dataloader
# if `dataset_cfg` is not a valid type
......@@ -891,9 +881,7 @@ class Runner:
sampler_cfg = dataloader_cfg.pop('sampler')
if isinstance(sampler_cfg, dict):
sampler = DATA_SAMPLERS.build(
sampler_cfg,
default_scope=self.default_scope.scope_name,
default_args=dict(dataset=dataset))
sampler_cfg, default_args=dict(dataset=dataset))
else:
# fallback to raise error in dataloader
# if `sampler_cfg` is not a valid type
......@@ -961,7 +949,6 @@ class Runner:
if 'type' in loop_cfg:
loop = LOOPS.build(
loop_cfg,
default_scope=self.default_scope.scope_name,
default_args=dict(
runner=self, dataloader=self.train_dataloader))
else:
......@@ -1012,7 +999,6 @@ class Runner:
if 'type' in loop_cfg:
loop = LOOPS.build(
loop_cfg,
default_scope=self.default_scope.scope_name,
default_args=dict(
runner=self,
dataloader=self.val_dataloader,
......@@ -1059,7 +1045,6 @@ class Runner:
if 'type' in loop_cfg:
loop = LOOPS.build(
loop_cfg,
default_scope=self.default_scope.scope_name,
default_args=dict(
runner=self,
dataloader=self.test_dataloader,
......
# Copyright (c) OpenMMLab. All rights reserved.
import time
import pytest
from mmengine.config import Config, ConfigDict # type: ignore
from mmengine.registry import Registry, build_from_cfg
from mmengine.registry import DefaultScope, Registry, build_from_cfg
class TestRegistry:
......@@ -342,11 +344,15 @@ class TestRegistry:
# test `default_scope`
# switch the current registry to another registry
dog = LITTLE_HOUNDS.build(b_cfg, default_scope='mid_hound')
DefaultScope.get_instance(
f'test-{time.time()}', scope_name='mid_hound')
dog = LITTLE_HOUNDS.build(b_cfg)
assert isinstance(dog, Beagle)
# `default_scope` can not be found
dog = MID_HOUNDS.build(b_cfg, default_scope='scope-not-found')
DefaultScope.get_instance(
f'test2-{time.time()}', scope_name='scope-not-found')
dog = MID_HOUNDS.build(b_cfg)
assert isinstance(dog, Beagle)
def test_repr(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment