From 25014af3c336717175537862e0fd72e8e20d2078 Mon Sep 17 00:00:00 2001
From: RangiLyu <lyuchqi@gmail.com>
Date: Fri, 1 Apr 2022 09:13:55 +0800
Subject: [PATCH] [Refactor] Refactor default_scope in Registry. (#158)

---
 docs/zh_cn/tutorials/registry.md              | 18 +++++++++----
 mmengine/evaluator/builder.py                 | 14 +++-------
 mmengine/optim/optimizer/builder.py           | 12 +++------
 .../optim/optimizer/default_constructor.py    |  8 +++---
 mmengine/registry/default_scope.py            |  2 --
 mmengine/registry/registry.py                 | 21 +++++++--------
 mmengine/runner/runner.py                     | 27 +++++--------------
 tests/test_registry/test_registry.py          | 12 ++++++---
 8 files changed, 48 insertions(+), 66 deletions(-)

diff --git a/docs/zh_cn/tutorials/registry.md b/docs/zh_cn/tutorials/registry.md
index 2febb6c2..09a38058 100644
--- a/docs/zh_cn/tutorials/registry.md
+++ b/docs/zh_cn/tutorials/registry.md
@@ -311,11 +311,15 @@ from mmcls.models import MODELS
 model = MODELS.build(cfg=dict(type='mmdet.RetinaNet'))
 ```
 
-调用兄弟节点的模块需要指定在 `type` 中指定 `scope` 前缀,如果不想指定,我们可以将 `build` 方法中的 `default_scope` 参数设置为 'mmdet',它会将 `default_scope` 对应的 `registry` 作为当前 `Registry` 并调用 `build` 方法。
+调用非本节点的模块需要指定在 `type` 中指定 `scope` 前缀,如果不想指定,我们可以创建一个全局变量 `default_scope` 并将 `scope_name` 设置为 'mmdet',`Registry` 会将 `scope_name` 对应的 `registry` 作为当前 `Registry` 并调用 `build` 方法。
 
 ```python
-from mmcls.models import MODELS
-model = MODELS.build(cfg=dict(type='RetinaNet'), default_scope='mmdet')
+from mmengine.registry import DefaultScope, MODELS
+
+# 调用注册在 mmdet 中的 RetinaNet
+default_scope = DefaultScope.get_instance(
+            'my_experiment', scope_name='mmdet')
+model = MODELS.build(cfg=dict(type='RetinaNet'))
 ```
 
 注册器除了支持两层结构,三层甚至更多层结构也是支持的。
@@ -325,7 +329,7 @@ model = MODELS.build(cfg=dict(type='RetinaNet'), default_scope='mmdet')
 `DetPlus` 中定义了模块 `MetaNet`,
 
 ```python
-from mmengine.model import Registry
+from mmengine.registry import Registry
 from mmdet.model import MODELS as MMDET_MODELS
 MODELS = Registry('model', parent=MMDET_MODELS, scope='det_plus')
 
@@ -354,6 +358,10 @@ model = MODELS.build(cfg=dict(type='mmcls.ResNet'))
 from mmcls.models import MODELS
 # 需要注意前缀的顺序,'detplus.mmdet.ResNet' 是不正确的
 model = MODELS.build(cfg=dict(type='mmdet.detplus.MetaNet'))
-# 当然,更简单的方法是直接设置 default_scope
+
+# 如果希望默认从 detplus 构建模型,设置可以 default_scope
+from mmengine.registry import DefaultScope
+default_scope = DefaultScope.get_instance(
+            'my_experiment', scope_name='detplus')
 model = MODELS.build(cfg=dict(type='MetaNet', default_scope='detplus'))
 ```
diff --git a/mmengine/evaluator/builder.py b/mmengine/evaluator/builder.py
index fcc80031..40fa03a3 100644
--- a/mmengine/evaluator/builder.py
+++ b/mmengine/evaluator/builder.py
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Optional, Union
+from typing import Union
 
 from ..registry import EVALUATORS
 from .base import BaseEvaluator
@@ -7,9 +7,7 @@ from .composed_evaluator import ComposedEvaluator
 
 
 def build_evaluator(
-    cfg: Union[dict, list],
-    default_scope: Optional[str] = None
-) -> Union[BaseEvaluator, ComposedEvaluator]:
+        cfg: Union[dict, list]) -> Union[BaseEvaluator, ComposedEvaluator]:
     """Build function of evaluator.
 
     When the evaluator config is a list, it will automatically build composed
@@ -18,16 +16,12 @@ def build_evaluator(
     Args:
         cfg (dict | list): Config of evaluator. When the config is a list, it
             will automatically build composed evaluators.
-        default_scope (str, optional): The ``default_scope`` is used to
-            reset the current registry. Defaults to None.
 
     Returns:
         BaseEvaluator or ComposedEvaluator: The built evaluator.
     """
     if isinstance(cfg, list):
-        evaluators = [
-            EVALUATORS.build(_cfg, default_scope=default_scope) for _cfg in cfg
-        ]
+        evaluators = [EVALUATORS.build(_cfg) for _cfg in cfg]
         return ComposedEvaluator(evaluators=evaluators)
     else:
-        return EVALUATORS.build(cfg, default_scope=default_scope)
+        return EVALUATORS.build(cfg)
diff --git a/mmengine/optim/optimizer/builder.py b/mmengine/optim/optimizer/builder.py
index a3e1612d..31350f6f 100644
--- a/mmengine/optim/optimizer/builder.py
+++ b/mmengine/optim/optimizer/builder.py
@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
 import inspect
-from typing import List, Optional
+from typing import List
 
 import torch
 import torch.nn as nn
@@ -30,10 +30,7 @@ def register_torch_optimizers() -> List[str]:
 TORCH_OPTIMIZERS = register_torch_optimizers()
 
 
-def build_optimizer(
-        model: nn.Module,
-        cfg: dict,
-        default_scope: Optional[str] = None) -> torch.optim.Optimizer:
+def build_optimizer(model: nn.Module, cfg: dict) -> torch.optim.Optimizer:
     """Build function of optimizer.
 
     If ``constructor`` is set in the ``cfg``, this method will build an
@@ -58,7 +55,6 @@ def build_optimizer(
         dict(
             type=constructor_type,
             optimizer_cfg=optimizer_cfg,
-            paramwise_cfg=paramwise_cfg),
-        default_scope=default_scope)
-    optimizer = optim_constructor(model, default_scope=default_scope)
+            paramwise_cfg=paramwise_cfg))
+    optimizer = optim_constructor(model)
     return optimizer
diff --git a/mmengine/optim/optimizer/default_constructor.py b/mmengine/optim/optimizer/default_constructor.py
index 18b9db47..f46cd208 100644
--- a/mmengine/optim/optimizer/default_constructor.py
+++ b/mmengine/optim/optimizer/default_constructor.py
@@ -241,9 +241,7 @@ class DefaultOptimizerConstructor:
                 prefix=child_prefix,
                 is_dcn_module=is_dcn_module)
 
-    def __call__(self,
-                 model: nn.Module,
-                 default_scope: Optional[str] = None) -> torch.optim.Optimizer:
+    def __call__(self, model: nn.Module) -> torch.optim.Optimizer:
         if hasattr(model, 'module'):
             model = model.module
 
@@ -251,11 +249,11 @@ class DefaultOptimizerConstructor:
         # if no paramwise option is specified, just use the global setting
         if not self.paramwise_cfg:
             optimizer_cfg['params'] = model.parameters()
-            return OPTIMIZERS.build(optimizer_cfg, default_scope=default_scope)
+            return OPTIMIZERS.build(optimizer_cfg)
 
         # set param-wise lr and weight decay recursively
         params: List = []
         self.add_params(params, model)
         optimizer_cfg['params'] = params
 
-        return OPTIMIZERS.build(optimizer_cfg, default_scope=default_scope)
+        return OPTIMIZERS.build(optimizer_cfg)
diff --git a/mmengine/registry/default_scope.py b/mmengine/registry/default_scope.py
index 204ac43d..dc2256f4 100644
--- a/mmengine/registry/default_scope.py
+++ b/mmengine/registry/default_scope.py
@@ -25,8 +25,6 @@ class DefaultScope(ManagerMixin):
         >>> DefaultScope.get_instance('task', scope_name='mmdet')
         >>> # Get default scope globally.
         >>> scope_name = DefaultScope.get_instance('task').scope_name
-        >>> # build model from cfg.
-        >>> model = MODELS.build(model_cfg, default_scope=scope_name)
     """
 
     def __init__(self, name: str, scope_name: str):
diff --git a/mmengine/registry/registry.py b/mmengine/registry/registry.py
index f0e59a7e..3ee7d4d6 100644
--- a/mmengine/registry/registry.py
+++ b/mmengine/registry/registry.py
@@ -7,6 +7,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
 
 from ..config import Config, ConfigDict
 from ..utils import is_seq_of
+from .default_scope import DefaultScope
 
 
 def build_from_cfg(
@@ -354,19 +355,13 @@ class Registry:
 
         return None
 
-    def build(self,
-              *args,
-              default_scope: Optional[str] = None,
-              **kwargs) -> Any:
+    def build(self, *args, **kwargs) -> Any:
         """Build an instance.
 
-        Build an instance by calling :attr:`build_func`. If
-        :attr:`default_scope` is given, :meth:`build` will firstly get the
-        responding registry and then call its own :meth:`build`.
-
-        Args:
-            default_scope (str, optional): The ``default_scope`` is used to
-                reset the current registry. Defaults to None.
+        Build an instance by calling :attr:`build_func`. If the global
+        variable default scope (:obj:`DefaultScope`) exists ,
+        :meth:`build` will firstly get the responding registry and then call
+        its own :meth:`build`.
 
         Examples:
             >>> from mmengine import Registry
@@ -379,9 +374,11 @@ class Registry:
             >>> cfg = dict(type='ResNet', depth=50)
             >>> model = MODELS.build(cfg)
         """
+        # get the global default scope
+        default_scope = DefaultScope.get_current_instance()
         if default_scope is not None:
             root = self._get_root_registry()
-            registry = root._search_child(default_scope)
+            registry = root._search_child(default_scope.scope_name)
             if registry is None:
                 # if `default_scope` can not be found, fallback to use self
                 warnings.warn(
diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index 5cb15fad..fea3dff2 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -675,8 +675,7 @@ class Runner:
         if isinstance(model, nn.Module):
             return model
         elif isinstance(model, dict):
-            return MODELS.build(
-                model, default_scope=self.default_scope.scope_name)
+            return MODELS.build(model)
         else:
             raise TypeError('model should be a nn.Module object or dict, '
                             f'but got {model}')
@@ -726,9 +725,7 @@ class Runner:
                     model = model.cuda()
         else:
             model = MODEL_WRAPPERS.build(
-                model_wrapper_cfg,
-                default_scope=self.default_scope.scope_name,
-                default_args=dict(model=self.model))
+                model_wrapper_cfg, default_args=dict(model=self.model))
 
         return model
 
@@ -750,10 +747,7 @@ class Runner:
         if isinstance(optimizer, Optimizer):
             return optimizer
         elif isinstance(optimizer, dict):
-            optimizer = build_optimizer(
-                self.model,
-                optimizer,
-                default_scope=self.default_scope.scope_name)
+            optimizer = build_optimizer(self.model, optimizer)
             return optimizer
         else:
             raise TypeError('optimizer should be an Optimizer object or dict, '
@@ -801,7 +795,6 @@ class Runner:
                 param_schedulers.append(
                     PARAM_SCHEDULERS.build(
                         _scheduler,
-                        default_scope=self.default_scope.scope_name,
                         default_args=dict(optimizer=self.optimizer)))
             else:
                 raise TypeError(
@@ -837,9 +830,7 @@ class Runner:
         if isinstance(evaluator, (BaseEvaluator, ComposedEvaluator)):
             return evaluator
         elif isinstance(evaluator, dict) or is_list_of(evaluator, dict):
-            return build_evaluator(
-                evaluator,
-                default_scope=self.default_scope.scope_name)  # type: ignore
+            return build_evaluator(evaluator)  # type: ignore
         else:
             raise TypeError(
                 'evaluator should be one of dict, list of dict, BaseEvaluator '
@@ -880,8 +871,7 @@ class Runner:
         # build dataset
         dataset_cfg = dataloader_cfg.pop('dataset')
         if isinstance(dataset_cfg, dict):
-            dataset = DATASETS.build(
-                dataset_cfg, default_scope=self.default_scope.scope_name)
+            dataset = DATASETS.build(dataset_cfg)
         else:
             # fallback to raise error in dataloader
             # if `dataset_cfg` is not a valid type
@@ -891,9 +881,7 @@ class Runner:
         sampler_cfg = dataloader_cfg.pop('sampler')
         if isinstance(sampler_cfg, dict):
             sampler = DATA_SAMPLERS.build(
-                sampler_cfg,
-                default_scope=self.default_scope.scope_name,
-                default_args=dict(dataset=dataset))
+                sampler_cfg, default_args=dict(dataset=dataset))
         else:
             # fallback to raise error in dataloader
             # if `sampler_cfg` is not a valid type
@@ -961,7 +949,6 @@ class Runner:
         if 'type' in loop_cfg:
             loop = LOOPS.build(
                 loop_cfg,
-                default_scope=self.default_scope.scope_name,
                 default_args=dict(
                     runner=self, dataloader=self.train_dataloader))
         else:
@@ -1012,7 +999,6 @@ class Runner:
         if 'type' in loop_cfg:
             loop = LOOPS.build(
                 loop_cfg,
-                default_scope=self.default_scope.scope_name,
                 default_args=dict(
                     runner=self,
                     dataloader=self.val_dataloader,
@@ -1059,7 +1045,6 @@ class Runner:
         if 'type' in loop_cfg:
             loop = LOOPS.build(
                 loop_cfg,
-                default_scope=self.default_scope.scope_name,
                 default_args=dict(
                     runner=self,
                     dataloader=self.test_dataloader,
diff --git a/tests/test_registry/test_registry.py b/tests/test_registry/test_registry.py
index 344762d7..76f1d7ce 100644
--- a/tests/test_registry/test_registry.py
+++ b/tests/test_registry/test_registry.py
@@ -1,8 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import time
+
 import pytest
 
 from mmengine.config import Config, ConfigDict  # type: ignore
-from mmengine.registry import Registry, build_from_cfg
+from mmengine.registry import DefaultScope, Registry, build_from_cfg
 
 
 class TestRegistry:
@@ -342,11 +344,15 @@ class TestRegistry:
 
         # test `default_scope`
         # switch the current registry to another registry
-        dog = LITTLE_HOUNDS.build(b_cfg, default_scope='mid_hound')
+        DefaultScope.get_instance(
+            f'test-{time.time()}', scope_name='mid_hound')
+        dog = LITTLE_HOUNDS.build(b_cfg)
         assert isinstance(dog, Beagle)
 
         # `default_scope` can not be found
-        dog = MID_HOUNDS.build(b_cfg, default_scope='scope-not-found')
+        DefaultScope.get_instance(
+            f'test2-{time.time()}', scope_name='scope-not-found')
+        dog = MID_HOUNDS.build(b_cfg)
         assert isinstance(dog, Beagle)
 
     def test_repr(self):
-- 
GitLab