Skip to content
Snippets Groups Projects
Unverified Commit 2f16ec69 authored by RangiLyu's avatar RangiLyu Committed by GitHub
Browse files

[Feature] Support overwrite default scope with "_scope_". (#275)

* [Feature] Support overwrite default scope with "_scope_".

* add ut

* add ut
parent 7a5d3c83
No related branches found
No related tags found
No related merge requests found
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import copy
import time
from contextlib import contextmanager
from typing import Generator, Optional
from mmengine.utils.manager import ManagerMixin, _accquire_lock, _release_lock
......@@ -71,3 +74,17 @@ class DefaultScope(ManagerMixin):
instance = None
_release_lock()
return instance
@classmethod
@contextmanager
def overwrite_default_scope(cls, scope_name: Optional[str]) -> Generator:
"""overwrite the current default scope with `scope_name`"""
if scope_name is None:
yield
else:
tmp = copy.deepcopy(cls._instance_dict)
cls.get_instance(f'overwrite-{time.time()}', scope_name=scope_name)
try:
yield
finally:
cls._instance_dict = tmp
......@@ -468,7 +468,7 @@ class Registry:
return None
def build(self, *args, **kwargs) -> Any:
def build(self, cfg, *args, **kwargs) -> Any:
"""Build an instance.
Build an instance by calling :attr:`build_func`. If the global
......@@ -487,27 +487,28 @@ class Registry:
>>> cfg = dict(type='ResNet', depth=50)
>>> model = MODELS.build(cfg)
"""
# get the global default scope
default_scope = DefaultScope.get_current_instance()
if default_scope is not None:
scope_name = default_scope.scope_name
root = self._get_root_registry()
registry = root._search_child(scope_name)
if registry is None:
# if `default_scope` can not be found, fallback to use self
warnings.warn(
f'Failed to search registry with scope "{scope_name}" in '
f'the "{root.name}" registry tree. '
f'As a workaround, the current "{self.name}" registry in '
f'"{self.scope}" is used to build instance. This may '
f'cause unexpected failure when running the built '
f'modules. Please check whether "{scope_name}" is a '
f'correct scope, or whether the registry is initialized.')
with DefaultScope.overwrite_default_scope(cfg.pop('_scope_', None)):
# get the global default scope
default_scope = DefaultScope.get_current_instance()
if default_scope is not None:
scope_name = default_scope.scope_name
root = self._get_root_registry()
registry = root._search_child(scope_name)
if registry is None:
# if `default_scope` can not be found, fallback to use self
warnings.warn(
f'Failed to search registry with scope "{scope_name}" '
f'in the "{root.name}" registry tree. '
f'As a workaround, the current "{self.name}" registry '
f'in "{self.scope}" is used to build instance. This '
f'may cause unexpected failure when running the built '
f'modules. Please check whether "{scope_name}" is a '
f'correct scope, or whether the registry is '
f'initialized.')
registry = self
else:
registry = self
else:
registry = self
return registry.build_func(*args, **kwargs, registry=registry)
return registry.build_func(cfg, *args, **kwargs, registry=registry)
def _add_child(self, registry: 'Registry') -> None:
"""Add a child for a registry.
......
......@@ -21,3 +21,15 @@ class TestDefaultScope:
DefaultScope.get_instance('instance_name', scope_name='mmengine')
default_scope = DefaultScope.get_current_instance()
assert default_scope.scope_name == 'mmengine'
def test_overwrite_default_scope(self):
origin_scope = DefaultScope.get_instance(
'test_overwrite_default_scope', scope_name='origin_scope')
with DefaultScope.overwrite_default_scope(scope_name=None):
assert DefaultScope.get_current_instance(
).scope_name == 'origin_scope'
with DefaultScope.overwrite_default_scope(scope_name='test_overwrite'):
assert DefaultScope.get_current_instance(
).scope_name == 'test_overwrite'
assert DefaultScope.get_current_instance(
).scope_name == origin_scope.scope_name == 'origin_scope'
......@@ -167,16 +167,17 @@ class TestRegistry:
registries = []
DOGS = Registry('dogs')
registries.append(DOGS)
HOUNDS = Registry('dogs', parent=DOGS, scope='hound')
HOUNDS = Registry('hounds', parent=DOGS, scope='hound')
registries.append(HOUNDS)
LITTLE_HOUNDS = Registry('dogs', parent=HOUNDS, scope='little_hound')
LITTLE_HOUNDS = Registry(
'little hounds', parent=HOUNDS, scope='little_hound')
registries.append(LITTLE_HOUNDS)
MID_HOUNDS = Registry('dogs', parent=HOUNDS, scope='mid_hound')
MID_HOUNDS = Registry('mid hounds', parent=HOUNDS, scope='mid_hound')
registries.append(MID_HOUNDS)
SAMOYEDS = Registry('dogs', parent=DOGS, scope='samoyed')
SAMOYEDS = Registry('samoyeds', parent=DOGS, scope='samoyed')
registries.append(SAMOYEDS)
LITTLE_SAMOYEDS = Registry(
'dogs', parent=SAMOYEDS, scope='little_samoyed')
'little samoyeds', parent=SAMOYEDS, scope='little_samoyed')
registries.append(LITTLE_SAMOYEDS)
return registries
......@@ -323,7 +324,7 @@ class TestRegistry:
# LITTLE_HOUNDS MID_HOUNDS LITTLE_SAMOYEDS
# (little_hound) (mid_hound) (little_samoyed)
registries = self._build_registry()
DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS = registries[:4]
DOGS, HOUNDS, LITTLE_HOUNDS, MID_HOUNDS, SAMOYEDS = registries[:5]
@DOGS.register_module()
class GoldenRetriever:
......@@ -367,6 +368,37 @@ class TestRegistry:
dog = MID_HOUNDS.build(b_cfg)
assert isinstance(dog, Beagle)
# test overwrite default scope with `_scope_`
@SAMOYEDS.register_module()
class MySamoyed:
def __init__(self, friend):
self.friend = DOGS.build(friend)
@SAMOYEDS.register_module()
class YourSamoyed:
pass
s_cfg = cfg_type(
dict(
_scope_='samoyed',
type='MySamoyed',
friend=dict(type='hound.BloodHound')))
dog = DOGS.build(s_cfg)
assert isinstance(dog, MySamoyed)
assert isinstance(dog.friend, BloodHound)
assert DefaultScope.get_current_instance().scope_name != 'samoyed'
s_cfg = cfg_type(
dict(
_scope_='samoyed',
type='MySamoyed',
friend=dict(type='YourSamoyed')))
dog = DOGS.build(s_cfg)
assert isinstance(dog, MySamoyed)
assert isinstance(dog.friend, YourSamoyed)
assert DefaultScope.get_current_instance().scope_name != 'samoyed'
def test_repr(self):
CATS = Registry('cat')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment