Unverified commit 563b4bad, authored by Mashiro and committed by GitHub

[Feature] add default scope (#149)

* add default scope

* Fix docstring

* override get_current_instance method in DefaultScope

clean meta naming

* remove default mmengine argument of DefaultScope

* Fix unit test


* Fix example in docstring

* add explanation of DefaultScope
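In short, a ``DefaultScope`` records which registry scope the current task runs in (for example ``mmdet``) and makes that scope retrievable from anywhere through the global instance dict provided by ``ManagerMixin``. A minimal sketch of the intended flow, assuming mmengine with this change installed (the instance name ``example_task`` is a placeholder, not part of this commit):

from mmengine.registry import DefaultScope

# The entry point (e.g. the Runner) registers the scope of the current task once.
DefaultScope.get_instance('example_task', scope_name='mmdet')

# Later, any internal module can recover the scope without a runner reference.
scope = DefaultScope.get_current_instance().scope_name
assert scope == 'mmdet'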
 # Copyright (c) OpenMMLab. All rights reserved.
+from .default_scope import DefaultScope
 from .registry import Registry, build_from_cfg
 from .root import (DATA_SAMPLERS, DATASETS, EVALUATORS, HOOKS, LOOPS,
                    MODEL_WRAPPERS, MODELS, OPTIMIZER_CONSTRUCTORS, OPTIMIZERS,
@@ -9,5 +10,6 @@ __all__ = [
     'Registry', 'build_from_cfg', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS',
     'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS',
     'OPTIMIZERS', 'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS',
-    'EVALUATORS', 'MODEL_WRAPPERS', 'LOOPS', 'WRITERS', 'VISUALIZERS'
+    'EVALUATORS', 'MODEL_WRAPPERS', 'LOOPS', 'WRITERS', 'VISUALIZERS',
+    'DefaultScope'
 ]
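With ``DefaultScope`` re-exported and listed in ``__all__``, downstream code can import it straight from the registry package, for example:

# DefaultScope is now part of mmengine.registry's public API.
from mmengine.registry import DefaultScope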
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

from mmengine.utils.manager import ManagerMixin, _accquire_lock, _release_lock


class DefaultScope(ManagerMixin):
    """Scope of the current task, used to reset the current registry; it can
    be accessed globally.

    Consider resetting the current ``Registry`` by ``default_scope`` in an
    internal module that cannot access the runner directly: it is difficult
    to get the ``default_scope`` defined in ``Runner``. However, if ``Runner``
    creates a ``DefaultScope`` instance with the given ``default_scope``, the
    internal module can get ``default_scope`` anywhere via
    ``DefaultScope.get_current_instance``.

    Args:
        name (str): Name of default scope for global access.
        scope_name (str): Scope of current task.

    Examples:
        >>> from mmengine import MODELS
        >>> # Define default scope in runner.
        >>> DefaultScope.get_instance('task', scope_name='mmdet')
        >>> # Get default scope globally.
        >>> scope_name = DefaultScope.get_instance('task').scope_name
        >>> # Build model from cfg.
        >>> model = MODELS.build(model_cfg, default_scope=scope_name)
    """

    def __init__(self, name: str, scope_name: str):
        super().__init__(name)
        self._scope_name = scope_name

    @property
    def scope_name(self) -> str:
        """
        Returns:
            str: The current scope name.
        """
        return self._scope_name

    @classmethod
    def get_current_instance(cls) -> Optional['DefaultScope']:
        """Get the latest created default scope.

        Since ``default_scope`` is an optional argument of ``Registry.build``,
        ``get_current_instance`` should return ``None`` if no ``DefaultScope``
        has been created yet.

        Examples:
            >>> default_scope = DefaultScope.get_current_instance()
            >>> # No `DefaultScope` has been created yet,
            >>> # so `get_current_instance` returns `None`.
            >>> default_scope = DefaultScope.get_instance(
            >>>     'instance_name', scope_name='mmengine')
            >>> default_scope.scope_name
            'mmengine'
            >>> default_scope = DefaultScope.get_current_instance()
            >>> default_scope.scope_name
            'mmengine'

        Returns:
            Optional[DefaultScope]: ``None`` if no ``DefaultScope`` instance
            has been created yet, otherwise the latest created
            ``DefaultScope`` instance.
        """
        _accquire_lock()
        if cls._instance_dict:
            instance = super(DefaultScope, cls).get_current_instance()
        else:
            instance = None
        _release_lock()
        return instance
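The ``None`` fallback above is what lets build helpers treat the global scope as optional. A minimal sketch of how a builder that cannot see the runner could use ``get_current_instance`` when no ``default_scope`` is passed explicitly; ``build_model_from_cfg`` is a hypothetical helper for illustration only, not part of this commit:

from mmengine.registry import MODELS, DefaultScope


def build_model_from_cfg(cfg, default_scope=None):
    # Hypothetical helper: fall back to the globally registered scope
    # when the caller does not pass `default_scope` explicitly.
    if default_scope is None:
        current = DefaultScope.get_current_instance()
        if current is not None:
            default_scope = current.scope_name
    return MODELS.build(cfg, default_scope=default_scope)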
@@ -30,7 +30,8 @@ from mmengine.logging import MessageHub, MMLogger
 from mmengine.model import is_model_wrapper
 from mmengine.optim import _ParamScheduler, build_optimizer
 from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
-                               MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS)
+                               MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS,
+                               DefaultScope)
 from mmengine.utils import find_latest_checkpoint, is_list_of, symlink
 from mmengine.visualization import ComposedWriter
 from .base_loop import BaseLoop
@@ -226,10 +227,6 @@ class Runner:
         else:
             self.cfg = dict()
-        # Used to reset registries location. See :meth:`Registry.build` for
-        # more details.
-        self.default_scope = default_scope
         self._epoch = 0
         self._iter = 0
         self._inner_iter = 0
@@ -305,6 +302,10 @@ class Runner:
         self.message_hub = self.build_message_hub(message_hub)
         # writer used for writing log or visualizing all kinds of data
         self.writer = self.build_writer(writer)
+        # Used to reset registries location. See :meth:`Registry.build` for
+        # more details.
+        self.default_scope = DefaultScope.get_instance(
+            self._experiment_name, scope_name=default_scope)
         self._load_from = load_from
         self._resume = resume
@@ -684,7 +685,8 @@
         if isinstance(model, nn.Module):
             return model
         elif isinstance(model, dict):
-            return MODELS.build(model, default_scope=self.default_scope)
+            return MODELS.build(
+                model, default_scope=self.default_scope.scope_name)
         else:
             raise TypeError('model should be a nn.Module object or dict, '
                             f'but got {model}')
@@ -735,7 +737,7 @@
         else:
             model = MODEL_WRAPPERS.build(
                 model_wrapper_cfg,
-                default_scope=self.default_scope,
+                default_scope=self.default_scope.scope_name,
                 default_args=dict(model=self.model))
         return model
@@ -759,7 +761,9 @@
             return optimizer
         elif isinstance(optimizer, dict):
             optimizer = build_optimizer(
-                self.model, optimizer, default_scope=self.default_scope)
+                self.model,
+                optimizer,
+                default_scope=self.default_scope.scope_name)
             return optimizer
         else:
             raise TypeError('optimizer should be an Optimizer object or dict, '
@@ -807,7 +811,7 @@
                 param_schedulers.append(
                     PARAM_SCHEDULERS.build(
                         _scheduler,
-                        default_scope=self.default_scope,
+                        default_scope=self.default_scope.scope_name,
                         default_args=dict(optimizer=self.optimizer)))
             else:
                 raise TypeError(
@@ -844,7 +848,8 @@
             return evaluator
         elif isinstance(evaluator, dict) or is_list_of(evaluator, dict):
             return build_evaluator(
-                evaluator, default_scope=self.default_scope)  # type: ignore
+                evaluator,
+                default_scope=self.default_scope.scope_name)  # type: ignore
         else:
             raise TypeError(
                 'evaluator should be one of dict, list of dict, BaseEvaluator '
@@ -886,7 +891,7 @@
         dataset_cfg = dataloader_cfg.pop('dataset')
         if isinstance(dataset_cfg, dict):
             dataset = DATASETS.build(
-                dataset_cfg, default_scope=self.default_scope)
+                dataset_cfg, default_scope=self.default_scope.scope_name)
         else:
             # fallback to raise error in dataloader
             # if `dataset_cfg` is not a valid type
@@ -897,7 +902,7 @@
         if isinstance(sampler_cfg, dict):
             sampler = DATA_SAMPLERS.build(
                 sampler_cfg,
-                default_scope=self.default_scope,
+                default_scope=self.default_scope.scope_name,
                 default_args=dict(dataset=dataset))
         else:
             # fallback to raise error in dataloader
@@ -966,7 +971,7 @@
         if 'type' in loop_cfg:
             loop = LOOPS.build(
                 loop_cfg,
-                default_scope=self.default_scope,
+                default_scope=self.default_scope.scope_name,
                 default_args=dict(
                     runner=self, dataloader=self.train_dataloader))
         else:
@@ -1017,7 +1022,7 @@
         if 'type' in loop_cfg:
             loop = LOOPS.build(
                 loop_cfg,
-                default_scope=self.default_scope,
+                default_scope=self.default_scope.scope_name,
                 default_args=dict(
                     runner=self,
                     dataloader=self.val_dataloader,
@@ -1064,7 +1069,7 @@
         if 'type' in loop_cfg:
             loop = LOOPS.build(
                 loop_cfg,
-                default_scope=self.default_scope,
+                default_scope=self.default_scope.scope_name,
                 default_args=dict(
                     runner=self,
                     dataloader=self.test_dataloader,
...
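On the runner side, ``self.default_scope`` is now a ``DefaultScope`` instance keyed by the experiment name rather than a plain string, so every build call passes ``self.default_scope.scope_name``. One consequence, sketched below with placeholder experiment names (not taken from this commit), is that each runner registers its own scope while ``get_current_instance`` returns the most recently created one:

from mmengine.registry import DefaultScope

# Placeholder experiment names; a real runner uses its own experiment name.
DefaultScope.get_instance('exp_det', scope_name='mmdet')
DefaultScope.get_instance('exp_cls', scope_name='mmcls')

# `get_current_instance` returns the most recently created instance.
assert DefaultScope.get_current_instance().scope_name == 'mmcls'

# A specific runner's scope is still reachable by its experiment name.
assert DefaultScope.get_instance('exp_det').scope_name == 'mmdet'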
# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict

import pytest

from mmengine.registry import DefaultScope


class TestDefaultScope:

    def test_scope(self):
        default_scope = DefaultScope.get_instance('name1', scope_name='mmdet')
        assert default_scope.scope_name == 'mmdet'

        # `DefaultScope.get_instance` must have `scope_name` argument.
        with pytest.raises(TypeError):
            DefaultScope.get_instance('name2')

    def test_get_current_instance(self):
        DefaultScope._instance_dict = OrderedDict()
        assert DefaultScope.get_current_instance() is None

        DefaultScope.get_instance('instance_name', scope_name='mmengine')
        default_scope = DefaultScope.get_current_instance()
        assert default_scope.scope_name == 'mmengine'
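A possible follow-up test, not part of this commit, for the "latest instance wins" behaviour documented in ``get_current_instance`` (it reuses the imports of the test module above):

def test_latest_instance_wins():
    """Sketch only: the most recently created DefaultScope is returned."""
    DefaultScope._instance_dict = OrderedDict()
    DefaultScope.get_instance('first', scope_name='mmdet')
    DefaultScope.get_instance('second', scope_name='mmcls')
    assert DefaultScope.get_current_instance().scope_name == 'mmcls'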