diff --git a/mmengine/registry/default_scope.py b/mmengine/registry/default_scope.py index 46eef72535636c8c54e97647e76800bf9e8e83d7..2f4d90ee6f6018ed7ff6a3f7293940217f823f7c 100644 --- a/mmengine/registry/default_scope.py +++ b/mmengine/registry/default_scope.py @@ -1,5 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Optional +import copy +import time +from contextlib import contextmanager +from typing import Generator, Optional from mmengine.utils.manager import ManagerMixin, _accquire_lock, _release_lock @@ -71,3 +74,17 @@ class DefaultScope(ManagerMixin): instance = None _release_lock() return instance + + @classmethod + @contextmanager + def overwrite_default_scope(cls, scope_name: Optional[str]) -> Generator: + """overwrite the current default scope with `scope_name`""" + if scope_name is None: + yield + else: + tmp = copy.deepcopy(cls._instance_dict) + cls.get_instance(f'overwrite-{time.time()}', scope_name=scope_name) + try: + yield + finally: + cls._instance_dict = tmp diff --git a/mmengine/registry/registry.py b/mmengine/registry/registry.py index 9bfd25a417d34ec8b0d24c58cee89c1b62aecbbe..4690ee4b1a1d45706c54fc7b16a10464a51ca1ed 100644 --- a/mmengine/registry/registry.py +++ b/mmengine/registry/registry.py @@ -468,7 +468,7 @@ class Registry: return None - def build(self, *args, **kwargs) -> Any: + def build(self, cfg, *args, **kwargs) -> Any: """Build an instance. Build an instance by calling :attr:`build_func`. If the global @@ -487,27 +487,28 @@ class Registry: >>> cfg = dict(type='ResNet', depth=50) >>> model = MODELS.build(cfg) """ - # get the global default scope - default_scope = DefaultScope.get_current_instance() - if default_scope is not None: - scope_name = default_scope.scope_name - root = self._get_root_registry() - registry = root._search_child(scope_name) - if registry is None: - # if `default_scope` can not be found, fallback to use self - warnings.warn( - f'Failed to search registry with scope "{scope_name}" in ' - f'the "{root.name}" registry tree. ' - f'As a workaround, the current "{self.name}" registry in ' - f'"{self.scope}" is used to build instance. This may ' - f'cause unexpected failure when running the built ' - f'modules. Please check whether "{scope_name}" is a ' - f'correct scope, or whether the registry is initialized.') + with DefaultScope.overwrite_default_scope(cfg.pop('_scope_', None)): + # get the global default scope + default_scope = DefaultScope.get_current_instance() + if default_scope is not None: + scope_name = default_scope.scope_name + root = self._get_root_registry() + registry = root._search_child(scope_name) + if registry is None: + # if `default_scope` can not be found, fallback to use self + warnings.warn( + f'Failed to search registry with scope "{scope_name}" ' + f'in the "{root.name}" registry tree. ' + f'As a workaround, the current "{self.name}" registry ' + f'in "{self.scope}" is used to build instance. This ' + f'may cause unexpected failure when running the built ' + f'modules. Please check whether "{scope_name}" is a ' + f'correct scope, or whether the registry is ' + f'initialized.') + registry = self + else: registry = self - else: - registry = self - - return registry.build_func(*args, **kwargs, registry=registry) + return registry.build_func(cfg, *args, **kwargs, registry=registry) def _add_child(self, registry: 'Registry') -> None: """Add a child for a registry. diff --git a/tests/test_registry/test_default_scope.py b/tests/test_registry/test_default_scope.py index 89a894e2e36485f0ff4eca30522a16508d29b2be..3846b15adef13300e92c72713f551566633dc7d5 100644 --- a/tests/test_registry/test_default_scope.py +++ b/tests/test_registry/test_default_scope.py @@ -21,3 +21,15 @@ class TestDefaultScope: DefaultScope.get_instance('instance_name', scope_name='mmengine') default_scope = DefaultScope.get_current_instance() assert default_scope.scope_name == 'mmengine' + + def test_overwrite_default_scope(self): + origin_scope = DefaultScope.get_instance( + 'test_overwrite_default_scope', scope_name='origin_scope') + with DefaultScope.overwrite_default_scope(scope_name=None): + assert DefaultScope.get_current_instance( + ).scope_name == 'origin_scope' + with DefaultScope.overwrite_default_scope(scope_name='test_overwrite'): + assert DefaultScope.get_current_instance( + ).scope_name == 'test_overwrite' + assert DefaultScope.get_current_instance( + ).scope_name == origin_scope.scope_name == 'origin_scope' diff --git a/tests/test_registry/test_registry.py b/tests/test_registry/test_registry.py index 2cc2ec848c88184295d322f9f6fab304cf97a4a0..9070df0970003f425004c3dbc28623cfe0320eea 100644 --- a/tests/test_registry/test_registry.py +++ b/tests/test_registry/test_registry.py @@ -167,16 +167,17 @@ class TestRegistry: registries = [] DOGS = Registry('dogs') registries.append(DOGS) - HOUNDS = Registry('dogs', parent=DOGS, scope='hound') + HOUNDS = Registry('hounds', parent=DOGS, scope='hound') registries.append(HOUNDS) - LITTLE_HOUNDS = Registry('dogs', parent=HOUNDS, scope='little_hound') + LITTLE_HOUNDS = Registry( + 'little hounds', parent=HOUNDS, scope='little_hound') registries.append(LITTLE_HOUNDS) - MID_HOUNDS = Registry('dogs', parent=HOUNDS, scope='mid_hound') + MID_HOUNDS = Registry('mid hounds', parent=HOUNDS, scope='mid_hound') registries.append(MID_HOUNDS) - SAMOYEDS = Registry('dogs', parent=DOGS, scope='samoyed') + SAMOYEDS = Registry('samoyeds', parent=DOGS, scope='samoyed') registries.append(SAMOYEDS) LITTLE_SAMOYEDS = Registry( - 'dogs', parent=SAMOYEDS, scope='little_samoyed') + 'little samoyeds', parent=SAMOYEDS, scope='little_samoyed') registries.append(LITTLE_SAMOYEDS) return registries @@ -323,7 +324,7 @@ class TestRegistry: # LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS # (little_hound) (mid_hound) (little_samoyed) registries = self._build_registry() - DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS = registries[:4] + DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS, SAMOYEDS = registries[:5] @DOGS.register_module() class GoldenRetriever: @@ -367,6 +368,37 @@ class TestRegistry: dog = MID_HOUNDS.build(b_cfg) assert isinstance(dog, Beagle) + # test overwrite default scope with `_scope_` + @SAMOYEDS.register_module() + class MySamoyed: + + def __init__(self, friend): + self.friend = DOGS.build(friend) + + @SAMOYEDS.register_module() + class YourSamoyed: + pass + + s_cfg = cfg_type( + dict( + _scope_='samoyed', + type='MySamoyed', + friend=dict(type='hound.BloodHound'))) + dog = DOGS.build(s_cfg) + assert isinstance(dog, MySamoyed) + assert isinstance(dog.friend, BloodHound) + assert DefaultScope.get_current_instance().scope_name != 'samoyed' + + s_cfg = cfg_type( + dict( + _scope_='samoyed', + type='MySamoyed', + friend=dict(type='YourSamoyed'))) + dog = DOGS.build(s_cfg) + assert isinstance(dog, MySamoyed) + assert isinstance(dog.friend, YourSamoyed) + assert DefaultScope.get_current_instance().scope_name != 'samoyed' + def test_repr(self): CATS = Registry('cat')