From 25014af3c336717175537862e0fd72e8e20d2078 Mon Sep 17 00:00:00 2001
From: RangiLyu <lyuchqi@gmail.com>
Date: Fri, 1 Apr 2022 09:13:55 +0800
Subject: [PATCH] [Refactor] Refactor default_scope in Registry. (#158)

---
 docs/zh_cn/tutorials/registry.md           | 18 +++++++++----
 mmengine/evaluator/builder.py              | 14 +++-------
 mmengine/optim/optimizer/builder.py        | 12 +++------
 .../optim/optimizer/default_constructor.py |  8 +++---
 mmengine/registry/default_scope.py         |  2 --
 mmengine/registry/registry.py              | 21 +++++++--------
 mmengine/runner/runner.py                  | 27 +++++--------------
 tests/test_registry/test_registry.py       | 12 ++++++---
 8 files changed, 48 insertions(+), 66 deletions(-)

diff --git a/docs/zh_cn/tutorials/registry.md b/docs/zh_cn/tutorials/registry.md
index 2febb6c2..09a38058 100644
--- a/docs/zh_cn/tutorials/registry.md
+++ b/docs/zh_cn/tutorials/registry.md
@@ -311,11 +311,15 @@ from mmcls.models import MODELS
 model = MODELS.build(cfg=dict(type='mmdet.RetinaNet'))
 ```
 
-To call a module registered in a sibling node, you need to specify the `scope` prefix in `type`. If you would rather not specify it, you can set the `default_scope` argument of the `build` method to 'mmdet'; the registry that `default_scope` refers to is then used as the current `Registry` and its `build` method is called.
+To call a module registered in another node, you need to specify the `scope` prefix in `type`. If you would rather not specify it, you can create a global `DefaultScope` variable with `scope_name` set to 'mmdet'; `Registry` will then use the registry that `scope_name` refers to as the current `Registry` and call its `build` method.
 
 ```python
-from mmcls.models import MODELS
-model = MODELS.build(cfg=dict(type='RetinaNet'), default_scope='mmdet')
+from mmengine.registry import DefaultScope, MODELS
+
+# build the RetinaNet registered in mmdet
+default_scope = DefaultScope.get_instance(
+    'my_experiment', scope_name='mmdet')
+model = MODELS.build(cfg=dict(type='RetinaNet'))
 ```
 
 Besides two-level structures, registries also support three or even more levels.
@@ -325,7 +329,7 @@ model = MODELS.build(cfg=dict(type='RetinaNet'), default_scope='mmdet')
 The module `MetaNet` is defined in `DetPlus`,
 
 ```python
-from mmengine.model import Registry
+from mmengine.registry import Registry
 from mmdet.model import MODELS as MMDET_MODELS
 
 MODELS = Registry('model', parent=MMDET_MODELS, scope='det_plus')
@@ -354,6 +358,10 @@ model = MODELS.build(cfg=dict(type='mmcls.ResNet'))
 from mmcls.models import MODELS
 # note the order of the prefixes: 'detplus.mmdet.ResNet' is incorrect
 model = MODELS.build(cfg=dict(type='mmdet.detplus.MetaNet'))
-# of course, a simpler way is to set default_scope directly
+
+# if you want to build models from detplus by default, you can set default_scope
+from mmengine.registry import DefaultScope
+default_scope = DefaultScope.get_instance(
+    'my_experiment', scope_name='detplus')
 model = MODELS.build(cfg=dict(type='MetaNet', default_scope='detplus'))
 ```
diff --git a/mmengine/evaluator/builder.py b/mmengine/evaluator/builder.py
index fcc80031..40fa03a3 100644
--- a/mmengine/evaluator/builder.py
+++ b/mmengine/evaluator/builder.py
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Optional, Union
+from typing import Union
 
 from ..registry import EVALUATORS
 from .base import BaseEvaluator
@@ -7,9 +7,7 @@ from .composed_evaluator import ComposedEvaluator
 
 
 def build_evaluator(
-        cfg: Union[dict, list],
-        default_scope: Optional[str] = None
-) -> Union[BaseEvaluator, ComposedEvaluator]:
+        cfg: Union[dict, list]) -> Union[BaseEvaluator, ComposedEvaluator]:
     """Build function of evaluator.
 
     When the evaluator config is a list, it will automatically build composed
     evaluators.
 
     Args:
         cfg (dict | list): Config of evaluator. When the config is a list, it
             will automatically build composed evaluators.
-        default_scope (str, optional): The ``default_scope`` is used to
-            reset the current registry. Defaults to None.
 
     Returns:
         BaseEvaluator or ComposedEvaluator: The built evaluator.
     """
     if isinstance(cfg, list):
-        evaluators = [
-            EVALUATORS.build(_cfg, default_scope=default_scope) for _cfg in cfg
-        ]
+        evaluators = [EVALUATORS.build(_cfg) for _cfg in cfg]
         return ComposedEvaluator(evaluators=evaluators)
     else:
-        return EVALUATORS.build(cfg, default_scope=default_scope)
+        return EVALUATORS.build(cfg)
diff --git a/mmengine/optim/optimizer/builder.py b/mmengine/optim/optimizer/builder.py
index a3e1612d..31350f6f 100644
--- a/mmengine/optim/optimizer/builder.py
+++ b/mmengine/optim/optimizer/builder.py
@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
 import inspect
-from typing import List, Optional
+from typing import List
 
 import torch
 import torch.nn as nn
@@ -30,10 +30,7 @@ def register_torch_optimizers() -> List[str]:
 TORCH_OPTIMIZERS = register_torch_optimizers()
 
 
-def build_optimizer(
-        model: nn.Module,
-        cfg: dict,
-        default_scope: Optional[str] = None) -> torch.optim.Optimizer:
+def build_optimizer(model: nn.Module, cfg: dict) -> torch.optim.Optimizer:
     """Build function of optimizer.
 
     If ``constructor`` is set in the ``cfg``, this method will build an
@@ -58,7 +55,6 @@
         dict(
             type=constructor_type,
             optimizer_cfg=optimizer_cfg,
-            paramwise_cfg=paramwise_cfg),
-        default_scope=default_scope)
-    optimizer = optim_constructor(model, default_scope=default_scope)
+            paramwise_cfg=paramwise_cfg))
+    optimizer = optim_constructor(model)
     return optimizer
diff --git a/mmengine/optim/optimizer/default_constructor.py b/mmengine/optim/optimizer/default_constructor.py
index 18b9db47..f46cd208 100644
--- a/mmengine/optim/optimizer/default_constructor.py
+++ b/mmengine/optim/optimizer/default_constructor.py
@@ -241,9 +241,7 @@ class DefaultOptimizerConstructor:
                 prefix=child_prefix,
                 is_dcn_module=is_dcn_module)
 
-    def __call__(self,
-                 model: nn.Module,
-                 default_scope: Optional[str] = None) -> torch.optim.Optimizer:
+    def __call__(self, model: nn.Module) -> torch.optim.Optimizer:
         if hasattr(model, 'module'):
             model = model.module
 
@@ -251,11 +249,11 @@ class DefaultOptimizerConstructor:
         # if no paramwise option is specified, just use the global setting
         if not self.paramwise_cfg:
             optimizer_cfg['params'] = model.parameters()
-            return OPTIMIZERS.build(optimizer_cfg, default_scope=default_scope)
+            return OPTIMIZERS.build(optimizer_cfg)
 
         # set param-wise lr and weight decay recursively
         params: List = []
         self.add_params(params, model)
         optimizer_cfg['params'] = params
-        return OPTIMIZERS.build(optimizer_cfg, default_scope=default_scope)
+        return OPTIMIZERS.build(optimizer_cfg)
diff --git a/mmengine/registry/default_scope.py b/mmengine/registry/default_scope.py
index 204ac43d..dc2256f4 100644
--- a/mmengine/registry/default_scope.py
+++ b/mmengine/registry/default_scope.py
@@ -25,8 +25,6 @@ class DefaultScope(ManagerMixin):
         >>> DefaultScope.get_instance('task', scope_name='mmdet')
         >>> # Get default scope globally.
         >>> scope_name = DefaultScope.get_instance('task').scope_name
-        >>> # build model from cfg.
-        >>> model = MODELS.build(model_cfg, default_scope=scope_name)
     """
 
     def __init__(self, name: str, scope_name: str):
diff --git a/mmengine/registry/registry.py b/mmengine/registry/registry.py
index f0e59a7e..3ee7d4d6 100644
--- a/mmengine/registry/registry.py
+++ b/mmengine/registry/registry.py
@@ -7,6 +7,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
 
 from ..config import Config, ConfigDict
 from ..utils import is_seq_of
+from .default_scope import DefaultScope
 
 
 def build_from_cfg(
@@ -354,19 +355,13 @@ class Registry:
 
         return None
 
-    def build(self,
-              *args,
-              default_scope: Optional[str] = None,
-              **kwargs) -> Any:
+    def build(self, *args, **kwargs) -> Any:
         """Build an instance.
 
-        Build an instance by calling :attr:`build_func`. If
-        :attr:`default_scope` is given, :meth:`build` will firstly get the
-        responding registry and then call its own :meth:`build`.
-
-        Args:
-            default_scope (str, optional): The ``default_scope`` is used to
-                reset the current registry. Defaults to None.
+        Build an instance by calling :attr:`build_func`. If the global
+        variable default scope (:obj:`DefaultScope`) exists, :meth:`build`
+        will first get the corresponding registry and then call its own
+        :meth:`build`.
 
         Examples:
             >>> from mmengine import Registry
             >>> cfg = dict(type='ResNet', depth=50)
             >>> model = MODELS.build(cfg)
         """
+        # get the global default scope
+        default_scope = DefaultScope.get_current_instance()
         if default_scope is not None:
             root = self._get_root_registry()
-            registry = root._search_child(default_scope)
+            registry = root._search_child(default_scope.scope_name)
             if registry is None:
                 # if `default_scope` can not be found, fallback to use self
                 warnings.warn(
diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index 5cb15fad..fea3dff2 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -675,8 +675,7 @@ class Runner:
         if isinstance(model, nn.Module):
             return model
         elif isinstance(model, dict):
-            return MODELS.build(
-                model, default_scope=self.default_scope.scope_name)
+            return MODELS.build(model)
         else:
             raise TypeError('model should be a nn.Module object or dict, '
                             f'but got {model}')
@@ -726,9 +725,7 @@ class Runner:
                 model = model.cuda()
         else:
             model = MODEL_WRAPPERS.build(
-                model_wrapper_cfg,
-                default_scope=self.default_scope.scope_name,
-                default_args=dict(model=self.model))
+                model_wrapper_cfg, default_args=dict(model=self.model))
 
         return model
 
@@ -750,10 +747,7 @@ class Runner:
         if isinstance(optimizer, Optimizer):
             return optimizer
         elif isinstance(optimizer, dict):
-            optimizer = build_optimizer(
-                self.model,
-                optimizer,
-                default_scope=self.default_scope.scope_name)
+            optimizer = build_optimizer(self.model, optimizer)
             return optimizer
         else:
             raise TypeError('optimizer should be an Optimizer object or dict, '
@@ -801,7 +795,6 @@ class Runner:
                 param_schedulers.append(
                     PARAM_SCHEDULERS.build(
                         _scheduler,
-                        default_scope=self.default_scope.scope_name,
                         default_args=dict(optimizer=self.optimizer)))
             else:
                 raise TypeError(
@@ -837,9 +830,7 @@ class Runner:
         if isinstance(evaluator, (BaseEvaluator, ComposedEvaluator)):
             return evaluator
         elif isinstance(evaluator, dict) or is_list_of(evaluator, dict):
-            return build_evaluator(
-                evaluator,
-                default_scope=self.default_scope.scope_name)  # type: ignore
+            return build_evaluator(evaluator)  # type: ignore
         else:
             raise TypeError(
                 'evaluator should be one of dict, list of dict, BaseEvaluator '
@@ -880,8 +871,7 @@ class Runner:
         # build dataset
         dataset_cfg = dataloader_cfg.pop('dataset')
         if isinstance(dataset_cfg, dict):
-            dataset = DATASETS.build(
-                dataset_cfg, default_scope=self.default_scope.scope_name)
+            dataset = DATASETS.build(dataset_cfg)
         else:
             # fallback to raise error in dataloader
             # if `dataset_cfg` is not a valid type
@@ -891,9 +881,7 @@ class Runner:
         sampler_cfg = dataloader_cfg.pop('sampler')
         if isinstance(sampler_cfg, dict):
             sampler = DATA_SAMPLERS.build(
-                sampler_cfg,
-                default_scope=self.default_scope.scope_name,
-                default_args=dict(dataset=dataset))
+                sampler_cfg, default_args=dict(dataset=dataset))
         else:
             # fallback to raise error in dataloader
             # if `sampler_cfg` is not a valid type
@@ -961,7 +949,6 @@ class Runner:
         if 'type' in loop_cfg:
             loop = LOOPS.build(
                 loop_cfg,
-                default_scope=self.default_scope.scope_name,
                 default_args=dict(
                     runner=self, dataloader=self.train_dataloader))
         else:
@@ -1012,7 +999,6 @@ class Runner:
         if 'type' in loop_cfg:
             loop = LOOPS.build(
                 loop_cfg,
-                default_scope=self.default_scope.scope_name,
                 default_args=dict(
                     runner=self,
                     dataloader=self.val_dataloader,
@@ -1059,7 +1045,6 @@ class Runner:
         if 'type' in loop_cfg:
             loop = LOOPS.build(
                 loop_cfg,
-                default_scope=self.default_scope.scope_name,
                 default_args=dict(
                     runner=self,
                     dataloader=self.test_dataloader,
diff --git a/tests/test_registry/test_registry.py b/tests/test_registry/test_registry.py
index 344762d7..76f1d7ce 100644
--- a/tests/test_registry/test_registry.py
+++ b/tests/test_registry/test_registry.py
@@ -1,8 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import time
+
 import pytest
 
 from mmengine.config import Config, ConfigDict  # type: ignore
-from mmengine.registry import Registry, build_from_cfg
+from mmengine.registry import DefaultScope, Registry, build_from_cfg
 
 
 class TestRegistry:
@@ -342,11 +344,15 @@ class TestRegistry:
 
         # test `default_scope`
        # switch the current registry to another registry
-        dog = LITTLE_HOUNDS.build(b_cfg, default_scope='mid_hound')
+        DefaultScope.get_instance(
+            f'test-{time.time()}', scope_name='mid_hound')
+        dog = LITTLE_HOUNDS.build(b_cfg)
         assert isinstance(dog, Beagle)
 
         # `default_scope` can not be found
-        dog = MID_HOUNDS.build(b_cfg, default_scope='scope-not-found')
+        DefaultScope.get_instance(
+            f'test2-{time.time()}', scope_name='scope-not-found')
+        dog = MID_HOUNDS.build(b_cfg)
         assert isinstance(dog, Beagle)
 
     def test_repr(self):
--
GitLab
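
As a quick orientation for reviewers, the following standalone sketch shows the usage pattern this patch enables. It is illustrative only: the 'core'/'det' scopes and the `DogNet` class are hypothetical names invented for the example, while `Registry`, `DefaultScope.get_instance`, and the prefix-free `build()` call mirror the changes above.

```python
# Hypothetical registries and model class, for illustration only.
import time

from mmengine.registry import DefaultScope, Registry

MODELS = Registry('model', scope='core')
DET_MODELS = Registry('model', parent=MODELS, scope='det')


@DET_MODELS.register_module()
class DogNet:

    def __init__(self, depth):
        self.depth = depth


# Before this patch: MODELS.build(cfg, default_scope='det').
# After it, a process-wide DefaultScope is created once (with a unique
# instance name, as the updated tests do) and Registry.build() looks it
# up via DefaultScope.get_current_instance().
DefaultScope.get_instance(f'example-{time.time()}', scope_name='det')

# No 'det.' prefix and no default_scope argument are needed any more:
# MODELS falls through to the registry registered under the 'det' scope.
model = MODELS.build(dict(type='DogNet', depth=18))
assert isinstance(model, DogNet) and model.depth == 18
```

Moving the scope from a per-call argument to a managed global state is also what allows the Runner and builder changes above to drop `default_scope` instead of threading it through every `build()` call.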