From 563b4bad165a0f2e4b721f8cc5b48ef4f7a48edc Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Mon, 28 Mar 2022 23:14:41 +0800
Subject: [PATCH] [Feature] add defaut scope (#149)

* add defaut scope

* Fix docstring

* override get_current_instance method in DefaultScope

clean meta nameing

* remove default mmengine argument of DefaltScope

remove default mmengine argument of DefaltScope

remove default mmengine argument of DefaltScope

* Fix unit test

Fix unit test

* Fix example in docstring

* add explaination of DefaultScope
---
 mmengine/registry/__init__.py             |  4 +-
 mmengine/registry/default_scope.py        | 75 +++++++++++++++++++++++
 mmengine/runner/runner.py                 | 35 ++++++-----
 tests/test_registry/test_default_scope.py | 23 +++++++
 4 files changed, 121 insertions(+), 16 deletions(-)
 create mode 100644 mmengine/registry/default_scope.py
 create mode 100644 tests/test_registry/test_default_scope.py

diff --git a/mmengine/registry/__init__.py b/mmengine/registry/__init__.py
index 069d437a..2299c17e 100644
--- a/mmengine/registry/__init__.py
+++ b/mmengine/registry/__init__.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .default_scope import DefaultScope
 from .registry import Registry, build_from_cfg
 from .root import (DATA_SAMPLERS, DATASETS, EVALUATORS, HOOKS, LOOPS,
                    MODEL_WRAPPERS, MODELS, OPTIMIZER_CONSTRUCTORS, OPTIMIZERS,
@@ -9,5 +10,6 @@ __all__ = [
     'Registry', 'build_from_cfg', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS',
     'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS',
     'OPTIMIZERS', 'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS',
-    'EVALUATORS', 'MODEL_WRAPPERS', 'LOOPS', 'WRITERS', 'VISUALIZERS'
+    'EVALUATORS', 'MODEL_WRAPPERS', 'LOOPS', 'WRITERS', 'VISUALIZERS',
+    'DefaultScope'
 ]
diff --git a/mmengine/registry/default_scope.py b/mmengine/registry/default_scope.py
new file mode 100644
index 00000000..7d221bef
--- /dev/null
+++ b/mmengine/registry/default_scope.py
@@ -0,0 +1,75 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional
+
+from mmengine.utils.manager import ManagerMixin, _accquire_lock, _release_lock
+
+
+class DefaultScope(ManagerMixin):
+    """Scope of current task used to reset the current registry, which can be
+    accessed globally.
+
+    Consider the case of reseting the current ``Resgitry`` by``default_scope``
+    in the internal module which cannot access runner directly, it is difficult
+    to get the ``default_scope`` defined in ``Runner``. However, if ``Runner``
+    created ``DefaultScope`` instance by given ``default_scope``, the internal
+    module can get ``default_scope`` by ``DefaultScope.get_current_instance``
+    everywhere.
+
+    Args:
+        name (str): Name of default scope for global access.
+        scope_name (str): Scope of current task.
+
+    Examples:
+        >>> from mmengine import MODELS
+        >>> # Define default scope in runner.
+        >>> DefaultScope.get_instance('task', scope_name='mmdet')
+        >>> # Get default scope globally.
+        >>> scope_name = DefaultScope.get_instance('task').scope_name
+        >>> # build model from cfg.
+        >>> model = MODELS.build(model_cfg, default_scope=scope_name)
+    """
+
+    def __init__(self, name: str, scope_name: str):
+        super().__init__(name)
+        self._scope_name = scope_name
+
+    @property
+    def scope_name(self) -> str:
+        """
+        Returns:
+            str: Get current scope.
+        """
+        return self._scope_name
+
+    @classmethod
+    def get_current_instance(cls) -> Optional['DefaultScope']:
+        """Get latest created default scope.
+
+        Since default_scope is an optional argument for ``Registry.build``.
+        ``get_current_instance`` should return ``None`` if there is no
+        ``DefaultScope`` created.
+
+        Examples:
+            >>> default_scope = DefaultScope.get_current_instance()
+            >>> # There is no `DefaultScope` created yet,
+            >>> # `get_current_instance` return `None`.
+            >>> default_scope = DefaultScope.get_instance(
+            >>>     'instance_name', scope_name='mmengine')
+            >>> default_scope.scope_name
+            mmengine
+            >>> default_scope = DefaultScope.get_current_instance()
+            >>> default_scope.scope_name
+            mmengine
+
+        Returns:
+            Optional[DefaultScope]: Return None If there has not been
+            ``DefaultScope`` instance created yet, otherwise return the
+            latest created DefaultScope instance.
+        """
+        _accquire_lock()
+        if cls._instance_dict:
+            instance = super(DefaultScope, cls).get_current_instance()
+        else:
+            instance = None
+        _release_lock()
+        return instance
diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index 42f61ced..6d458c02 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -30,7 +30,8 @@ from mmengine.logging import MessageHub, MMLogger
 from mmengine.model import is_model_wrapper
 from mmengine.optim import _ParamScheduler, build_optimizer
 from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
-                               MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS)
+                               MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS,
+                               DefaultScope)
 from mmengine.utils import find_latest_checkpoint, is_list_of, symlink
 from mmengine.visualization import ComposedWriter
 from .base_loop import BaseLoop
@@ -226,10 +227,6 @@ class Runner:
         else:
             self.cfg = dict()
 
-        # Used to reset registries location. See :meth:`Registry.build` for
-        # more details.
-        self.default_scope = default_scope
-
         self._epoch = 0
         self._iter = 0
         self._inner_iter = 0
@@ -305,6 +302,10 @@ class Runner:
         self.message_hub = self.build_message_hub(message_hub)
         # writer used for writing log or visualizing all kinds of data
         self.writer = self.build_writer(writer)
+        # Used to reset registries location. See :meth:`Registry.build` for
+        # more details.
+        self.default_scope = DefaultScope.get_instance(
+            self._experiment_name, scope_name=default_scope)
 
         self._load_from = load_from
         self._resume = resume
@@ -684,7 +685,8 @@ class Runner:
         if isinstance(model, nn.Module):
             return model
         elif isinstance(model, dict):
-            return MODELS.build(model, default_scope=self.default_scope)
+            return MODELS.build(
+                model, default_scope=self.default_scope.scope_name)
         else:
             raise TypeError('model should be a nn.Module object or dict, '
                             f'but got {model}')
@@ -735,7 +737,7 @@ class Runner:
         else:
             model = MODEL_WRAPPERS.build(
                 model_wrapper_cfg,
-                default_scope=self.default_scope,
+                default_scope=self.default_scope.scope_name,
                 default_args=dict(model=self.model))
 
         return model
@@ -759,7 +761,9 @@ class Runner:
             return optimizer
         elif isinstance(optimizer, dict):
             optimizer = build_optimizer(
-                self.model, optimizer, default_scope=self.default_scope)
+                self.model,
+                optimizer,
+                default_scope=self.default_scope.scope_name)
             return optimizer
         else:
             raise TypeError('optimizer should be an Optimizer object or dict, '
@@ -807,7 +811,7 @@ class Runner:
                 param_schedulers.append(
                     PARAM_SCHEDULERS.build(
                         _scheduler,
-                        default_scope=self.default_scope,
+                        default_scope=self.default_scope.scope_name,
                         default_args=dict(optimizer=self.optimizer)))
             else:
                 raise TypeError(
@@ -844,7 +848,8 @@ class Runner:
             return evaluator
         elif isinstance(evaluator, dict) or is_list_of(evaluator, dict):
             return build_evaluator(
-                evaluator, default_scope=self.default_scope)  # type: ignore
+                evaluator,
+                default_scope=self.default_scope.scope_name)  # type: ignore
         else:
             raise TypeError(
                 'evaluator should be one of dict, list of dict, BaseEvaluator '
@@ -886,7 +891,7 @@ class Runner:
         dataset_cfg = dataloader_cfg.pop('dataset')
         if isinstance(dataset_cfg, dict):
             dataset = DATASETS.build(
-                dataset_cfg, default_scope=self.default_scope)
+                dataset_cfg, default_scope=self.default_scope.scope_name)
         else:
             # fallback to raise error in dataloader
             # if `dataset_cfg` is not a valid type
@@ -897,7 +902,7 @@ class Runner:
         if isinstance(sampler_cfg, dict):
             sampler = DATA_SAMPLERS.build(
                 sampler_cfg,
-                default_scope=self.default_scope,
+                default_scope=self.default_scope.scope_name,
                 default_args=dict(dataset=dataset))
         else:
             # fallback to raise error in dataloader
@@ -966,7 +971,7 @@ class Runner:
         if 'type' in loop_cfg:
             loop = LOOPS.build(
                 loop_cfg,
-                default_scope=self.default_scope,
+                default_scope=self.default_scope.scope_name,
                 default_args=dict(
                     runner=self, dataloader=self.train_dataloader))
         else:
@@ -1017,7 +1022,7 @@ class Runner:
         if 'type' in loop_cfg:
             loop = LOOPS.build(
                 loop_cfg,
-                default_scope=self.default_scope,
+                default_scope=self.default_scope.scope_name,
                 default_args=dict(
                     runner=self,
                     dataloader=self.val_dataloader,
@@ -1064,7 +1069,7 @@ class Runner:
         if 'type' in loop_cfg:
             loop = LOOPS.build(
                 loop_cfg,
-                default_scope=self.default_scope,
+                default_scope=self.default_scope.scope_name,
                 default_args=dict(
                     runner=self,
                     dataloader=self.test_dataloader,
diff --git a/tests/test_registry/test_default_scope.py b/tests/test_registry/test_default_scope.py
new file mode 100644
index 00000000..89a894e2
--- /dev/null
+++ b/tests/test_registry/test_default_scope.py
@@ -0,0 +1,23 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import OrderedDict
+
+import pytest
+
+from mmengine.registry import DefaultScope
+
+
+class TestDefaultScope:
+
+    def test_scope(self):
+        default_scope = DefaultScope.get_instance('name1', scope_name='mmdet')
+        assert default_scope.scope_name == 'mmdet'
+        # `DefaultScope.get_instance` must have `scope_name` argument.
+        with pytest.raises(TypeError):
+            DefaultScope.get_instance('name2')
+
+    def test_get_current_instance(self):
+        DefaultScope._instance_dict = OrderedDict()
+        assert DefaultScope.get_current_instance() is None
+        DefaultScope.get_instance('instance_name', scope_name='mmengine')
+        default_scope = DefaultScope.get_current_instance()
+        assert default_scope.scope_name == 'mmengine'
-- 
GitLab