From 563b4bad165a0f2e4b721f8cc5b48ef4f7a48edc Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Mon, 28 Mar 2022 23:14:41 +0800 Subject: [PATCH] [Feature] add defaut scope (#149) * add defaut scope * Fix docstring * override get_current_instance method in DefaultScope clean meta nameing * remove default mmengine argument of DefaltScope remove default mmengine argument of DefaltScope remove default mmengine argument of DefaltScope * Fix unit test Fix unit test * Fix example in docstring * add explaination of DefaultScope --- mmengine/registry/__init__.py | 4 +- mmengine/registry/default_scope.py | 75 +++++++++++++++++++++++ mmengine/runner/runner.py | 35 ++++++----- tests/test_registry/test_default_scope.py | 23 +++++++ 4 files changed, 121 insertions(+), 16 deletions(-) create mode 100644 mmengine/registry/default_scope.py create mode 100644 tests/test_registry/test_default_scope.py diff --git a/mmengine/registry/__init__.py b/mmengine/registry/__init__.py index 069d437a..2299c17e 100644 --- a/mmengine/registry/__init__.py +++ b/mmengine/registry/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .default_scope import DefaultScope from .registry import Registry, build_from_cfg from .root import (DATA_SAMPLERS, DATASETS, EVALUATORS, HOOKS, LOOPS, MODEL_WRAPPERS, MODELS, OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, @@ -9,5 +10,6 @@ __all__ = [ 'Registry', 'build_from_cfg', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS', 'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS', 'OPTIMIZERS', 'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS', - 'EVALUATORS', 'MODEL_WRAPPERS', 'LOOPS', 'WRITERS', 'VISUALIZERS' + 'EVALUATORS', 'MODEL_WRAPPERS', 'LOOPS', 'WRITERS', 'VISUALIZERS', + 'DefaultScope' ] diff --git a/mmengine/registry/default_scope.py b/mmengine/registry/default_scope.py new file mode 100644 index 00000000..7d221bef --- /dev/null +++ b/mmengine/registry/default_scope.py @@ -0,0 +1,75 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +from mmengine.utils.manager import ManagerMixin, _accquire_lock, _release_lock + + +class DefaultScope(ManagerMixin): + """Scope of current task used to reset the current registry, which can be + accessed globally. + + Consider the case of reseting the current ``Resgitry`` by``default_scope`` + in the internal module which cannot access runner directly, it is difficult + to get the ``default_scope`` defined in ``Runner``. However, if ``Runner`` + created ``DefaultScope`` instance by given ``default_scope``, the internal + module can get ``default_scope`` by ``DefaultScope.get_current_instance`` + everywhere. + + Args: + name (str): Name of default scope for global access. + scope_name (str): Scope of current task. + + Examples: + >>> from mmengine import MODELS + >>> # Define default scope in runner. + >>> DefaultScope.get_instance('task', scope_name='mmdet') + >>> # Get default scope globally. + >>> scope_name = DefaultScope.get_instance('task').scope_name + >>> # build model from cfg. + >>> model = MODELS.build(model_cfg, default_scope=scope_name) + """ + + def __init__(self, name: str, scope_name: str): + super().__init__(name) + self._scope_name = scope_name + + @property + def scope_name(self) -> str: + """ + Returns: + str: Get current scope. + """ + return self._scope_name + + @classmethod + def get_current_instance(cls) -> Optional['DefaultScope']: + """Get latest created default scope. + + Since default_scope is an optional argument for ``Registry.build``. + ``get_current_instance`` should return ``None`` if there is no + ``DefaultScope`` created. + + Examples: + >>> default_scope = DefaultScope.get_current_instance() + >>> # There is no `DefaultScope` created yet, + >>> # `get_current_instance` return `None`. + >>> default_scope = DefaultScope.get_instance( + >>> 'instance_name', scope_name='mmengine') + >>> default_scope.scope_name + mmengine + >>> default_scope = DefaultScope.get_current_instance() + >>> default_scope.scope_name + mmengine + + Returns: + Optional[DefaultScope]: Return None If there has not been + ``DefaultScope`` instance created yet, otherwise return the + latest created DefaultScope instance. + """ + _accquire_lock() + if cls._instance_dict: + instance = super(DefaultScope, cls).get_current_instance() + else: + instance = None + _release_lock() + return instance diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index 42f61ced..6d458c02 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -30,7 +30,8 @@ from mmengine.logging import MessageHub, MMLogger from mmengine.model import is_model_wrapper from mmengine.optim import _ParamScheduler, build_optimizer from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, - MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS) + MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS, + DefaultScope) from mmengine.utils import find_latest_checkpoint, is_list_of, symlink from mmengine.visualization import ComposedWriter from .base_loop import BaseLoop @@ -226,10 +227,6 @@ class Runner: else: self.cfg = dict() - # Used to reset registries location. See :meth:`Registry.build` for - # more details. - self.default_scope = default_scope - self._epoch = 0 self._iter = 0 self._inner_iter = 0 @@ -305,6 +302,10 @@ class Runner: self.message_hub = self.build_message_hub(message_hub) # writer used for writing log or visualizing all kinds of data self.writer = self.build_writer(writer) + # Used to reset registries location. See :meth:`Registry.build` for + # more details. + self.default_scope = DefaultScope.get_instance( + self._experiment_name, scope_name=default_scope) self._load_from = load_from self._resume = resume @@ -684,7 +685,8 @@ class Runner: if isinstance(model, nn.Module): return model elif isinstance(model, dict): - return MODELS.build(model, default_scope=self.default_scope) + return MODELS.build( + model, default_scope=self.default_scope.scope_name) else: raise TypeError('model should be a nn.Module object or dict, ' f'but got {model}') @@ -735,7 +737,7 @@ class Runner: else: model = MODEL_WRAPPERS.build( model_wrapper_cfg, - default_scope=self.default_scope, + default_scope=self.default_scope.scope_name, default_args=dict(model=self.model)) return model @@ -759,7 +761,9 @@ class Runner: return optimizer elif isinstance(optimizer, dict): optimizer = build_optimizer( - self.model, optimizer, default_scope=self.default_scope) + self.model, + optimizer, + default_scope=self.default_scope.scope_name) return optimizer else: raise TypeError('optimizer should be an Optimizer object or dict, ' @@ -807,7 +811,7 @@ class Runner: param_schedulers.append( PARAM_SCHEDULERS.build( _scheduler, - default_scope=self.default_scope, + default_scope=self.default_scope.scope_name, default_args=dict(optimizer=self.optimizer))) else: raise TypeError( @@ -844,7 +848,8 @@ class Runner: return evaluator elif isinstance(evaluator, dict) or is_list_of(evaluator, dict): return build_evaluator( - evaluator, default_scope=self.default_scope) # type: ignore + evaluator, + default_scope=self.default_scope.scope_name) # type: ignore else: raise TypeError( 'evaluator should be one of dict, list of dict, BaseEvaluator ' @@ -886,7 +891,7 @@ class Runner: dataset_cfg = dataloader_cfg.pop('dataset') if isinstance(dataset_cfg, dict): dataset = DATASETS.build( - dataset_cfg, default_scope=self.default_scope) + dataset_cfg, default_scope=self.default_scope.scope_name) else: # fallback to raise error in dataloader # if `dataset_cfg` is not a valid type @@ -897,7 +902,7 @@ class Runner: if isinstance(sampler_cfg, dict): sampler = DATA_SAMPLERS.build( sampler_cfg, - default_scope=self.default_scope, + default_scope=self.default_scope.scope_name, default_args=dict(dataset=dataset)) else: # fallback to raise error in dataloader @@ -966,7 +971,7 @@ class Runner: if 'type' in loop_cfg: loop = LOOPS.build( loop_cfg, - default_scope=self.default_scope, + default_scope=self.default_scope.scope_name, default_args=dict( runner=self, dataloader=self.train_dataloader)) else: @@ -1017,7 +1022,7 @@ class Runner: if 'type' in loop_cfg: loop = LOOPS.build( loop_cfg, - default_scope=self.default_scope, + default_scope=self.default_scope.scope_name, default_args=dict( runner=self, dataloader=self.val_dataloader, @@ -1064,7 +1069,7 @@ class Runner: if 'type' in loop_cfg: loop = LOOPS.build( loop_cfg, - default_scope=self.default_scope, + default_scope=self.default_scope.scope_name, default_args=dict( runner=self, dataloader=self.test_dataloader, diff --git a/tests/test_registry/test_default_scope.py b/tests/test_registry/test_default_scope.py new file mode 100644 index 00000000..89a894e2 --- /dev/null +++ b/tests/test_registry/test_default_scope.py @@ -0,0 +1,23 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import OrderedDict + +import pytest + +from mmengine.registry import DefaultScope + + +class TestDefaultScope: + + def test_scope(self): + default_scope = DefaultScope.get_instance('name1', scope_name='mmdet') + assert default_scope.scope_name == 'mmdet' + # `DefaultScope.get_instance` must have `scope_name` argument. + with pytest.raises(TypeError): + DefaultScope.get_instance('name2') + + def test_get_current_instance(self): + DefaultScope._instance_dict = OrderedDict() + assert DefaultScope.get_current_instance() is None + DefaultScope.get_instance('instance_name', scope_name='mmengine') + default_scope = DefaultScope.get_current_instance() + assert default_scope.scope_name == 'mmengine' -- GitLab