From 7e2302388f3948470a71967a45d7b7e52ec03b3a Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Mon, 8 Aug 2022 21:01:06 +0800 Subject: [PATCH] [Feature] add config new feature (#105) --- mmengine/config/__init__.py | 3 +- mmengine/config/config.py | 144 ++++++++++++++++-- mmengine/config/get_config_model.py | 80 ++++++++++ mmengine/config/utils.py | 117 ++++++++++++++ mmengine/registry/registry.py | 9 ++ mmengine/utils/__init__.py | 5 +- mmengine/utils/package_utils.py | 71 +++++++++ .../config/py_config/test_get_external_cfg.py | 7 + .../py_config/test_get_external_cfg2.py | 2 + .../py_config/test_get_external_cfg3.py | 18 +++ .../py_config/test_get_external_cfg_base.py | 2 + tests/test_config/test_collect_meta.py | 43 ++++++ tests/test_config/test_config.py | 93 ++++++++++- tests/test_config/test_get_config_model.py | 51 +++++++ tests/test_registry/test_registry.py | 4 +- tests/test_visualizer/test_visualizer.py | 6 +- 16 files changed, 632 insertions(+), 23 deletions(-) create mode 100644 mmengine/config/get_config_model.py create mode 100644 mmengine/utils/package_utils.py create mode 100644 tests/data/config/py_config/test_get_external_cfg.py create mode 100644 tests/data/config/py_config/test_get_external_cfg2.py create mode 100644 tests/data/config/py_config/test_get_external_cfg3.py create mode 100644 tests/data/config/py_config/test_get_external_cfg_base.py create mode 100644 tests/test_config/test_collect_meta.py create mode 100644 tests/test_config/test_get_config_model.py diff --git a/mmengine/config/__init__.py b/mmengine/config/__init__.py index abf09ab2..bdf47067 100644 --- a/mmengine/config/__init__.py +++ b/mmengine/config/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. from .config import Config, ConfigDict, DictAction +from .get_config_model import get_config, get_model -__all__ = ['Config', 'ConfigDict', 'DictAction'] +__all__ = ['Config', 'ConfigDict', 'DictAction', 'get_config', 'get_model'] diff --git a/mmengine/config/config.py b/mmengine/config/config.py index 0e065802..18c76943 100644 --- a/mmengine/config/config.py +++ b/mmengine/config/config.py @@ -12,14 +12,16 @@ import warnings from argparse import Action, ArgumentParser, Namespace from collections import abc from pathlib import Path -from typing import Any, List, Optional, Sequence, Tuple, Union +from typing import Any, Optional, Sequence, Tuple, Union from addict import Dict from yapf.yapflib.yapf_api import FormatCode from mmengine.fileio import dump, load -from mmengine.utils import check_file_exist, import_modules_from_strings -from .utils import RemoveAssignFromAST +from mmengine.utils import (check_file_exist, check_install_package, + get_installed_path, import_modules_from_strings) +from .utils import (RemoveAssignFromAST, _get_external_cfg_base_path, + _get_external_cfg_path, _get_package_and_cfg_path) BASE_KEY = '_base_' DELETE_KEY = '_delete_' @@ -393,16 +395,26 @@ class Config: # Handle base files base_cfg_dict = ConfigDict() cfg_text_list = list() - for base_cfg_path in Config._parse_base_files( - temp_config_file.name): - cfg_dir = osp.dirname(filename) - _cfg_dict, _cfg_text = Config._file2dict( - osp.join(cfg_dir, base_cfg_path)) + for base_cfg_path in Config._get_base_files(temp_config_file.name): + base_cfg_path, scope = Config._get_cfg_path( + base_cfg_path, filename) + _cfg_dict, _cfg_text = Config._file2dict(base_cfg_path) cfg_text_list.append(_cfg_text) duplicate_keys = base_cfg_dict.keys() & _cfg_dict.keys() if len(duplicate_keys) > 0: raise KeyError('Duplicate key is not allowed among bases. ' f'Duplicate keys: {duplicate_keys}') + + # _dict_to_config_dict will do the following things: + # 1. Recursively converts ``dict`` to :obj:`ConfigDict`. + # 2. Set `_scope_` for the outer dict variable for the base + # config. + # 3. Set `scope` attribute for each base variable. Different + # from `_scope_`, `scope` is not a key of base dict, + # `scope` attribute will be parsed to key `_scope_` by + # function `_parse_scope` only if the base variable is + # accessed by the current config. + _cfg_dict = Config._dict_to_config_dict(_cfg_dict, scope) base_cfg_dict.update(_cfg_dict) if filename.endswith('.py'): @@ -428,6 +440,11 @@ class Config: cfg_dict.pop(key) temp_config_file.close() + # If the current config accesses a base variable of base + # configs, The ``scope`` attribute of corresponding variable + # will be converted to the `_scope_`. + Config._parse_scope(cfg_dict) + # check deprecation information if DEPRECATION_KEY in cfg_dict: deprecation_info = cfg_dict.pop(DEPRECATION_KEY) @@ -464,19 +481,76 @@ class Config: return cfg_dict, cfg_text @staticmethod - def _parse_base_files(file_path: str) -> List[str]: - """Get paths of all base config files. + def _dict_to_config_dict(cfg: dict, + scope: Optional[str] = None, + has_scope=True): + """Recursively converts ``dict`` to :obj:`ConfigDict`. Args: - file_path (str): Path of config. + cfg (dict): Config dict. + scope (str, optional): Scope of instance. + has_scope (bool): Whether to add `_scope_` key to config dict. Returns: - List[str]: paths of all base files . + ConfigDict: Converted dict. + """ + # Only the outer dict with key `type` should have the key `_scope_`. + if isinstance(cfg, dict): + if has_scope and 'type' in cfg: + has_scope = False + if scope is not None and cfg.get('_scope_', None) is None: + cfg._scope_ = scope # type: ignore + cfg = ConfigDict(cfg) + dict.__setattr__(cfg, 'scope', scope) + for key, value in cfg.items(): + cfg[key] = Config._dict_to_config_dict( + value, scope=scope, has_scope=has_scope) + elif isinstance(cfg, tuple): + cfg = tuple( + Config._dict_to_config_dict(_cfg, scope, has_scope=has_scope) + for _cfg in cfg) + elif isinstance(cfg, list): + cfg = [ + Config._dict_to_config_dict(_cfg, scope, has_scope=has_scope) + for _cfg in cfg + ] + return cfg + + @staticmethod + def _parse_scope(cfg: dict) -> None: + """Adds ``_scope_`` to :obj:`ConfigDict` instance, which means a base + variable. + + If the config dict already has the scope, scope will not be + overwritten. + + Args: + cfg (dict): Config needs to be parsed with scope. """ - file_format = file_path.partition('.')[-1] + if isinstance(cfg, ConfigDict): + cfg._scope_ = cfg.scope + elif isinstance(cfg, (tuple, list)): + [Config._parse_scope(value) for value in cfg] + else: + return + + @staticmethod + def _get_base_files(filename: str) -> list: + """Get the base config file. + + Args: + filename (str): The config file. + + Raises: + TypeError: Name of config file. + + Returns: + list: A list of base config + """ + file_format = filename.partition('.')[-1] if file_format == 'py': - Config._validate_py_syntax(file_path) - with open(file_path) as f: + Config._validate_py_syntax(filename) + with open(filename) as f: codes = ast.parse(f.read()).body def is_base_line(c): @@ -492,7 +566,8 @@ class Config: else: base_files = [] elif file_format in ('yml', 'yaml', 'json'): - cfg_dict = load(file_path) + import mmengine + cfg_dict = mmengine.load(filename) base_files = cfg_dict.get(BASE_KEY, []) else: raise TypeError('The config type should be py, json, yaml or ' @@ -501,6 +576,43 @@ class Config: list) else [base_files] return base_files + @staticmethod + def _get_cfg_path(cfg_path: str, + filename: str) -> Tuple[str, Optional[str]]: + """Get the config path from the current or external package. + + Args: + cfg_path (str): Relative path of config. + filename (str): The config file being parsed. + + Returns: + Tuple[str, str or None]: Path and scope of config. If the config + is not an external config, the scope will be `None`. + """ + if '::' in cfg_path: + # `cfg_path` startswith '::' means an external config path. + # Get package name and relative config path. + scope = cfg_path.partition('::')[0] + package, cfg_path = _get_package_and_cfg_path(cfg_path) + # Get installed package path. + check_install_package(package) + package_path = get_installed_path(package) + try: + # Get config path from meta file. + cfg_path = _get_external_cfg_path(package_path, cfg_path) + except ValueError: + # Since base config does not have a metafile, it should be + # concatenated with package path and relative config path. + cfg_path = _get_external_cfg_base_path(package_path, cfg_path) + except FileNotFoundError as e: + raise e + return cfg_path, scope + else: + # Get local config path. + cfg_dir = osp.dirname(filename) + cfg_path = osp.join(cfg_dir, cfg_path) + return cfg_path, None + @staticmethod def _merge_a_into_b(a: dict, b: dict, diff --git a/mmengine/config/get_config_model.py b/mmengine/config/get_config_model.py new file mode 100644 index 00000000..14b5936a --- /dev/null +++ b/mmengine/config/get_config_model.py @@ -0,0 +1,80 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import importlib +import os.path as osp + +import torch.nn as nn + +from mmengine.registry import MODELS, DefaultScope +from mmengine.utils import check_install_package, get_installed_path +from .config import Config +from .utils import (_get_cfg_metainfo, _get_external_cfg_base_path, + _get_package_and_cfg_path) + + +def get_config(cfg_path: str, pretrained: bool = False) -> Config: + """Get config from external package. + + Args: + cfg_path (str): External relative config path. + pretrained (bool): Whether to save pretrained model path. If + ``pretrained==True``, the url of pretrained model can be accessed + by ``cfg.model_path``. Defaults to False. + + Examples: + >>> cfg = get_config('mmdet::faster_rcnn/faster_rcnn_r50_fpn_1x_coco', + >>> pretrained=True) + >>> # Equivalent to + >>> Config.fromfile('/path/to/faster_rcnn_r50_fpn_1x_coco.py') + >>> cfg.model_path + https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth + + Returns: + Config: A `Config` parsed from external package. + """ # noqa E301 + # Get package name and relative config path. + package, cfg_path = _get_package_and_cfg_path(cfg_path) + # Check package is installed. + check_install_package(package) + package_path = get_installed_path(package) + try: + # Use `cfg_path` to search target config file. + cfg_meta = _get_cfg_metainfo(package_path, cfg_path) + cfg_path = osp.join(package_path, '.mim', cfg_meta['Config']) + cfg = Config.fromfile(cfg_path) + if pretrained: + assert 'Weights' in cfg_meta, ('Cannot find `Weights` in cfg_file' + '.metafile.yml, please check the' + 'metafile') + cfg.model_path = cfg_meta['Weights'] + except ValueError: + # Since the base config does not contain a metafile, the absolute + # config is `osp.join(package_path, cfg_path_prefix, cfg_name)` + cfg_path = _get_external_cfg_base_path(package_path, cfg_path) + cfg = Config.fromfile(cfg_path) + except Exception as e: + raise e + return cfg + + +def get_model(cfg_path: str, pretrained: bool = False, **kwargs) -> nn.Module: + """Get built model from external package. + + Args: + cfg_path (str): External relative config path with prefix + 'package::' and without suffix. + pretrained (bool): Whether to load pretrained model. Defaults to False. + kwargs (dict): Default arguments to build model. + + Returns: + nn.Module: Built model. + """ + import mmengine.runner + package = cfg_path.split('::')[0] + with DefaultScope.overwrite_default_scope(package): # type: ignore + cfg = get_config(cfg_path, pretrained) + models_module = importlib.import_module(f'{package}.utils') + models_module.register_all_modules() # type: ignore + model = MODELS.build(cfg.model, default_args=kwargs) + if pretrained: + mmengine.runner.load_checkpoint(model, cfg.model_path) + return model diff --git a/mmengine/config/utils.py b/mmengine/config/utils.py index 54c32691..7eec2c65 100644 --- a/mmengine/config/utils.py +++ b/mmengine/config/utils.py @@ -1,5 +1,122 @@ # Copyright (c) OpenMMLab. All rights reserved. import ast +import os.path as osp +import re +import warnings +from typing import Tuple + +from mmengine.fileio import load +from mmengine.utils import check_file_exist + +PKG2PROJECT = { + 'mmcls': 'mmcls', + 'mmdet': 'mmdet', + 'mmdet3d': 'mmdet3d', + 'mmseg': 'mmsegmentation', + 'mmaction2': 'mmaction2', + 'mmtrack': 'mmtrack', + 'mmpose': 'mmpose', + 'mmedit': 'mmedit', + 'mmocr': 'mmocr', + 'mmgen': 'mmgen', + 'mmfewshot': 'mmfewshot', + 'mmrazor': 'mmrazor', + 'mmflow': 'mmflow', + 'mmhuman3d': 'mmhuman3d', + 'mmrotate': 'mmrotate', + 'mmselfsup': 'mmselfsup', +} + + +def _get_cfg_metainfo(package_path: str, cfg_path: str) -> dict: + """Get target meta information from all 'metafile.yml' defined in `mode- + index.yml` of external package. + + Args: + package_path (str): Path of external package. + cfg_path (str): Name of experiment config. + + Returns: + dict: Meta information of target experiment. + """ + meta_index_path = osp.join(package_path, '.mim', 'model-index.yml') + meta_index = load(meta_index_path) + cfg_dict = dict() + for meta_path in meta_index['Import']: + meta_path = osp.join(package_path, '.mim', meta_path) + cfg_meta = load(meta_path) + for model_cfg in cfg_meta['Models']: + if 'Config' not in model_cfg: + warnings.warn(f'There is not `Config` define in {model_cfg}') + continue + cfg_name = model_cfg['Config'].partition('/')[-1] + # Some config could have multiple weights, we only pick the + # first one. + if cfg_name in cfg_dict: + continue + cfg_dict[cfg_name] = model_cfg + if cfg_path not in cfg_dict: + raise ValueError(f'Expected configs: {cfg_dict.keys()}, but got ' + f'{cfg_path}') + return cfg_dict[cfg_path] + + +def _get_external_cfg_path(package_path: str, cfg_file: str) -> str: + """Get config path of external package. + + Args: + package_path (str): Path of external package. + cfg_file (str): Name of experiment config. + + Returns: + str: Absolute config path from external package. + """ + cfg_file = cfg_file.split('.')[0] + model_cfg = _get_cfg_metainfo(package_path, cfg_file) + cfg_path = osp.join(package_path, model_cfg['Config']) + check_file_exist(cfg_path) + return cfg_path + + +def _get_external_cfg_base_path(package_path: str, cfg_name: str) -> str: + """Get base config path of external package. + + Args: + package_path (str): Path of external package. + cfg_name (str): External relative config path with 'package::'. + + Returns: + str: Absolute config path from external package. + """ + cfg_path = osp.join(package_path, '.mim', 'configs', cfg_name) + check_file_exist(cfg_path) + return cfg_path + + +def _get_package_and_cfg_path(cfg_path: str) -> Tuple[str, str]: + """Get package name and relative config path. + + Args: + cfg_path (str): External relative config path with 'package::'. + + Returns: + Tuple[str, str]: Package name and config path. + """ + if re.match(r'\w*::\w*/\w*', cfg_path) is None: + raise ValueError( + '`_get_package_and_cfg_path` is used for get external package, ' + 'please specify the package name and relative config path, just ' + 'like `mmdet::faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py`') + package_cfg = cfg_path.split('::') + if len(package_cfg) > 2: + raise ValueError('`::` should only be used to separate package and ' + 'config name, but found multiple `::` in ' + f'{cfg_path}') + package, cfg_path = package_cfg + assert package in PKG2PROJECT, 'mmengine does not support to load ' \ + f'{package} config.' + package = PKG2PROJECT[package] + return package, cfg_path class RemoveAssignFromAST(ast.NodeTransformer): diff --git a/mmengine/registry/registry.py b/mmengine/registry/registry.py index 351aabee..5fc4e56f 100644 --- a/mmengine/registry/registry.py +++ b/mmengine/registry/registry.py @@ -4,8 +4,10 @@ import logging import sys from collections.abc import Callable from contextlib import contextmanager +from importlib import import_module from typing import Any, Dict, Generator, List, Optional, Tuple, Type, Union +from ..config.utils import PKG2PROJECT from ..utils import is_seq_of from .default_scope import DefaultScope @@ -239,6 +241,13 @@ class Registry: # Get registry by scope if default_scope is not None: scope_name = default_scope.scope_name + if scope_name in PKG2PROJECT: + try: + module = import_module( + f'{PKG2PROJECT[scope_name]}.utils') + module.register_all_modules() # type: ignore + except ImportError as e: + raise e root = self._get_root_registry() registry = root._search_child(scope_name) if registry is None: diff --git a/mmengine/utils/__init__.py b/mmengine/utils/__init__.py index df080378..51ad2e41 100644 --- a/mmengine/utils/__init__.py +++ b/mmengine/utils/__init__.py @@ -8,6 +8,8 @@ from .misc import (check_prerequisites, concat_list, deprecated_api_warning, requires_executable, requires_package, slice_list, to_1tuple, to_2tuple, to_3tuple, to_4tuple, to_ntuple, tuple_cast) +from .package_utils import (call_command, check_install_package, + get_installed_path, is_installed) from .parrots_wrapper import TORCH_VERSION from .path import (check_file_exist, fopen, is_abs, is_filepath, mkdir_or_exist, scandir, symlink) @@ -28,5 +30,6 @@ __all__ = [ 'is_method_overridden', 'has_method', 'mmcv_full_available', 'digit_version', 'get_git_hash', 'TORCH_VERSION', 'load_url', 'ManagerMeta', 'ManagerMixin', 'set_multi_processing', 'has_batch_norm', - 'is_abs', 'revert_sync_batchnorm' + 'is_abs', 'is_installed', 'call_command', 'get_installed_path', + 'check_install_package', 'is_abs', 'revert_sync_batchnorm' ] diff --git a/mmengine/utils/package_utils.py b/mmengine/utils/package_utils.py new file mode 100644 index 00000000..dfd41eb1 --- /dev/null +++ b/mmengine/utils/package_utils.py @@ -0,0 +1,71 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import importlib +import os.path as osp +import subprocess + +import pkg_resources +from pkg_resources import get_distribution + + +def is_installed(package: str) -> bool: + """Check package whether installed. + + Args: + package (str): Name of package to be checked. + """ + # refresh the pkg_resources + # more datails at https://github.com/pypa/setuptools/issues/373 + importlib.reload(pkg_resources) + try: + get_distribution(package) + return True + except pkg_resources.DistributionNotFound: + return False + + +def get_installed_path(package: str) -> str: + """Get installed path of package. + + Args: + package (str): Name of package. + + Example: + >>> get_installed_path('mmcls') + >>> '.../lib/python3.7/site-packages/mmcls' + """ + # if the package name is not the same as module name, module name should be + # inferred. For example, mmcv-full is the package name, but mmcv is module + # name. If we want to get the installed path of mmcv-full, we should concat + # the pkg.location and module name + pkg = get_distribution(package) + possible_path = osp.join(pkg.location, package) + if osp.exists(possible_path): + return possible_path + else: + return osp.join(pkg.location, package2module(package)) + + +def package2module(package: str): + """Infer module name from package. + + Args: + package (str): Package to infer module name. + """ + pkg = get_distribution(package) + if pkg.has_metadata('top_level.txt'): + module_name = pkg.get_metadata('top_level.txt').split('\n')[0] + return module_name + else: + raise ValueError(f'can not infer the module name of {package}') + + +def call_command(cmd: list) -> None: + try: + subprocess.check_call(cmd) + except Exception as e: + raise e # type: ignore + + +def check_install_package(package: str): + if not is_installed(package): + call_command(['python', '-m', 'pip', 'install', package]) diff --git a/tests/data/config/py_config/test_get_external_cfg.py b/tests/data/config/py_config/test_get_external_cfg.py new file mode 100644 index 00000000..51a72f30 --- /dev/null +++ b/tests/data/config/py_config/test_get_external_cfg.py @@ -0,0 +1,7 @@ +# Copyright (c) OpenMMLab. All rights reserved. +_base_ = [ + 'mmdet::_base_/models/faster_rcnn_r50_fpn.py', + 'mmdet::_base_/datasets/coco_detection.py', + 'mmdet::_base_/schedules/schedule_1x.py', + 'mmdet::_base_/default_runtime.py' +] diff --git a/tests/data/config/py_config/test_get_external_cfg2.py b/tests/data/config/py_config/test_get_external_cfg2.py new file mode 100644 index 00000000..a2d60cee --- /dev/null +++ b/tests/data/config/py_config/test_get_external_cfg2.py @@ -0,0 +1,2 @@ +# Copyright (c) OpenMMLab. All rights reserved. +_base_ = 'mmdet::faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py' diff --git a/tests/data/config/py_config/test_get_external_cfg3.py b/tests/data/config/py_config/test_get_external_cfg3.py new file mode 100644 index 00000000..7313b803 --- /dev/null +++ b/tests/data/config/py_config/test_get_external_cfg3.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. +_base_ = [ + 'mmdet::_base_/models/faster_rcnn_r50_fpn.py', + 'mmdet::_base_/datasets/coco_detection.py', + 'mmdet::_base_/schedules/schedule_1x.py', + 'mmdet::_base_/default_runtime.py', + './test_get_external_cfg_base.py' +] + +custom_hooks = [dict(type='mmdet.DetVisualizationHook')] + +model = dict( + roi_head=dict( + bbox_head=dict( + loss_cls=dict(_delete_=True, type='test.ToyLoss') + ) + ) +) diff --git a/tests/data/config/py_config/test_get_external_cfg_base.py b/tests/data/config/py_config/test_get_external_cfg_base.py new file mode 100644 index 00000000..d680ef0a --- /dev/null +++ b/tests/data/config/py_config/test_get_external_cfg_base.py @@ -0,0 +1,2 @@ +# Copyright (c) OpenMMLab. All rights reserved. +toy_model = dict(type='ToyModel') diff --git a/tests/test_config/test_collect_meta.py b/tests/test_config/test_collect_meta.py new file mode 100644 index 00000000..f7b5e73a --- /dev/null +++ b/tests/test_config/test_collect_meta.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path + +import pytest + +from mmengine.config.utils import (_get_external_cfg_base_path, + _get_package_and_cfg_path) + + +def test_get_external_cfg_base_path(tmp_path): + package_path = tmp_path + rel_cfg_path = 'cfg_dir/cfg_file' + with pytest.raises(FileNotFoundError): + _get_external_cfg_base_path(str(package_path), rel_cfg_path) + cfg_dir = tmp_path / '.mim' / 'configs' / 'cfg_dir' + cfg_dir.mkdir(parents=True, exist_ok=True) + f = open(cfg_dir / 'cfg_file', 'w') + f.close() + cfg_path = _get_external_cfg_base_path(str(package_path), rel_cfg_path) + assert cfg_path == f'{os.path.join(str(cfg_dir), "cfg_file")}' + + +def test_get_external_cfg_path(): + external_cfg_path = 'mmdet::path/cfg' + package, rel_cfg_path = _get_package_and_cfg_path(external_cfg_path) + assert package == 'mmdet' + assert rel_cfg_path == 'path/cfg' + # external config must contain `::`. + external_cfg_path = 'path/cfg' + with pytest.raises(ValueError): + _get_package_and_cfg_path(external_cfg_path) + # Use `:::` as operator will raise an error. + external_cfg_path = 'mmdet:::path/cfg' + with pytest.raises(ValueError): + _get_package_and_cfg_path(external_cfg_path) + # Use `:` as operator will raise an error. + external_cfg_path = 'mmdet:path/cfg' + with pytest.raises(ValueError): + _get_package_and_cfg_path(external_cfg_path) + # Too much `::` + external_cfg_path = 'mmdet::path/cfg::error' + with pytest.raises(ValueError): + _get_package_and_cfg_path(external_cfg_path) diff --git a/tests/test_config/test_config.py b/tests/test_config/test_config.py index 671bdc12..8c33856c 100644 --- a/tests/test_config/test_config.py +++ b/tests/test_config/test_config.py @@ -12,6 +12,8 @@ import pytest from mmengine import Config, ConfigDict, DictAction from mmengine.fileio import dump, load +from mmengine.registry import MODELS, DefaultScope, Registry +from mmengine.utils import is_installed class TestConfig: @@ -186,6 +188,21 @@ class TestConfig: # overwritten by int sys.argv.extend(tmp) + def test_dict_to_config_dict(self): + cfg_dict = dict( + a=1, b=dict(c=dict()), d=[dict(e=dict(f=(dict(g=1), [])))]) + cfg_dict = Config._dict_to_config_dict(cfg_dict) + assert isinstance(cfg_dict, ConfigDict) + assert isinstance(cfg_dict.a, int) + assert isinstance(cfg_dict.b, ConfigDict) + assert isinstance(cfg_dict.b.c, ConfigDict) + assert isinstance(cfg_dict.d, list) + assert isinstance(cfg_dict.d[0], ConfigDict) + assert isinstance(cfg_dict.d[0].e, ConfigDict) + assert isinstance(cfg_dict.d[0].e.f, tuple) + assert isinstance(cfg_dict.d[0].e.f[0], ConfigDict) + assert isinstance(cfg_dict.d[0].e.f[1], list) + def test_dump(self, tmp_path): file_path = 'config/py_config/test_merge_from_multiple_bases.py' cfg_file = osp.join(self.data_path, file_path) @@ -266,7 +283,7 @@ class TestConfig: print(cfg, file=f) with open(tmp_txt) as f: assert f.read().strip() == f'Config (path: {cfg.filename}): ' \ - f'{cfg._cfg_dict.__repr__()}' + f'{cfg._cfg_dict.__repr__()}' def test_dict_action(self): parser = argparse.ArgumentParser(description='Train a detector') @@ -418,6 +435,26 @@ class TestConfig: self._merge_recursive_bases() self._deprecation() + def test_get_cfg_path(self): + filename = 'py_config/simple_config.py' + filename = osp.join(self.data_path, 'config', filename) + cfg_name = './base.py' + cfg_path, scope = Config._get_cfg_path(cfg_name, filename) + assert scope is None + osp.isfile(cfg_path) + + # Test scope equal to package name. + cfg_name = 'mmdet::faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py' + cfg_path, scope = Config._get_cfg_path(cfg_name, filename) + assert scope == 'mmdet' + osp.isfile(cfg_path) + + # Test scope does not equal to package name. + cfg_name = 'mmcls::cspnet/cspresnet50_8xb32_in1k.py' + cfg_path, scope = Config._get_cfg_path(cfg_name, filename) + assert scope == 'mmcls' + osp.isfile(cfg_path) + def _simple_load(self): # test load simple config for file_format in ['py', 'json', 'yaml']: @@ -750,3 +787,57 @@ class TestConfig: assert new_cfg._cfg_dict is cfg._cfg_dict assert new_cfg._filename == cfg._filename assert new_cfg._text == cfg._text + + def test_get_external_cfg(self): + ext_cfg_path = osp.join(self.data_path, + 'config/py_config/test_get_external_cfg.py') + ext_cfg = Config.fromfile(ext_cfg_path) + assert ext_cfg._cfg_dict.model.neck == dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=5, + ) + assert '_scope_' in ext_cfg._cfg_dict.model + + @pytest.mark.skipif( + not is_installed('mmdet'), reason='mmdet should be installed') + def test_build_external_package(self): + # Test load base config. + ext_cfg_path = osp.join(self.data_path, + 'config/py_config/test_get_external_cfg.py') + ext_cfg = Config.fromfile(ext_cfg_path) + + LOCAL_MODELS = Registry('local_model', parent=MODELS, scope='test') + LOCAL_MODELS.build(ext_cfg.model) + + # Test load non-base config + ext_cfg_path = osp.join(self.data_path, + 'config/py_config/test_get_external_cfg2.py') + ext_cfg = Config.fromfile(ext_cfg_path) + LOCAL_MODELS.build(ext_cfg.model) + + # Test override base variable. + ext_cfg_path = osp.join(self.data_path, + 'config/py_config/test_get_external_cfg3.py') + ext_cfg = Config.fromfile(ext_cfg_path) + + @LOCAL_MODELS.register_module() + class ToyLoss: + pass + + @LOCAL_MODELS.register_module() + class ToyModel: + pass + + DefaultScope.get_instance('test1', scope_name='test') + assert ext_cfg.model._scope_ == 'mmdet' + model = LOCAL_MODELS.build(ext_cfg.model) + + # Local base config should not have scope. + assert '_scope_' not in ext_cfg.toy_model + toy_model = LOCAL_MODELS.build(ext_cfg.toy_model) + assert isinstance(toy_model, ToyModel) + assert model.backbone.style == 'pytorch' + assert isinstance(model.roi_head.bbox_head.loss_cls, ToyLoss) + DefaultScope._instance_dict.pop('test1') diff --git a/tests/test_config/test_get_config_model.py b/tests/test_config/test_get_config_model.py new file mode 100644 index 00000000..812ab824 --- /dev/null +++ b/tests/test_config/test_get_config_model.py @@ -0,0 +1,51 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp + +import pytest + +from mmengine import Config, DefaultScope, get_config, get_model +from mmengine.utils import get_installed_path, is_installed + +data_path = osp.join(osp.dirname(osp.dirname(__file__)), 'data/') + + +# mmdet has a more typical config structure, while mmpose has a complex +# config structure +@pytest.mark.skipif( + not (is_installed('mmdet') and is_installed('mmpose')), + reason='mmdet and mmpose should be installed') +def test_get_config(): + # Test load base config. + base_cfg = get_config('mmdet::_base_/models/faster_rcnn_r50_fpn.py') + package_path = get_installed_path('mmdet') + test_base_cfg = Config.fromfile( + osp.join(package_path, '.mim', + 'configs/_base_/models/faster_rcnn_r50_fpn.py')) + assert test_base_cfg._cfg_dict == base_cfg._cfg_dict + + # Test load faster_rcnn config + cfg = get_config('mmdet::faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py') + test_cfg = Config.fromfile( + osp.join(package_path, '.mim', + 'configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py')) + assert cfg._cfg_dict == test_cfg._cfg_dict + + # Test pretrained + cfg = get_config( + 'mmdet::faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py', pretrained=True) + assert cfg.model_path == 'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth' # noqa E301 + + # Test load mmpose + get_config( + 'mmpose::face/2d_kpt_sview_rgb_img/deeppose/wflw/res50_wflw_256x256' + '.py') + + +@pytest.mark.skipif( + not is_installed('mmdet'), reason='mmdet and mmpose should be installed') +def test_get_model(): + # TODO compatible with downstream codebase. + DefaultScope.get_instance('test_get_model', scope_name='test_scope') + get_model('mmdet::faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py') + assert DefaultScope.get_current_instance().scope_name == 'test_scope' + DefaultScope._instance_dict.pop('test_get_model') diff --git a/tests/test_registry/test_registry.py b/tests/test_registry/test_registry.py index 74e51e22..d11fac13 100644 --- a/tests/test_registry/test_registry.py +++ b/tests/test_registry/test_registry.py @@ -193,7 +193,7 @@ class TestRegistry: return registries - def test_get_root_registry(self): + def test__get_root_registry(self): # Hierarchical Registry # DOGS # _______|_______ @@ -304,7 +304,7 @@ class TestRegistry: assert DOGS.get('samoyed.LittlePedigreeSamoyed') is None assert LITTLE_HOUNDS.get('mid_hound.PedigreeSamoyedddddd') is None - def test_search_child(self): + def test__search_child(self): # Hierarchical Registry # DOGS # _______|_______ diff --git a/tests/test_visualizer/test_visualizer.py b/tests/test_visualizer/test_visualizer.py index cc8fa2b8..a487501f 100644 --- a/tests/test_visualizer/test_visualizer.py +++ b/tests/test_visualizer/test_visualizer.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy +import time from typing import Any from unittest import TestCase @@ -117,12 +118,13 @@ class TestVisualizer(TestCase): save_dir='temp_dir') # test global init + instance_name = 'visualizer' + str(time.time()) visualizer = Visualizer.get_instance( - 'visualizer', + instance_name, vis_backends=copy.deepcopy(self.vis_backend_cfg), save_dir='temp_dir') assert len(visualizer._vis_backends) == 2 - visualizer_any = Visualizer.get_instance('visualizer') + visualizer_any = Visualizer.get_instance(instance_name) assert visualizer_any == visualizer def test_set_image(self): -- GitLab