Skip to content
Snippets Groups Projects
Unverified Commit 2d3bff44 authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

[Fix] Fix unit test of `Config` will install `mmdet` and `mmcls`. (#492)


* fix unit test install mmdet and mmcls

* raise error when mmdet is not installed

* rename check_and_install to install_package

* split test case

Co-authored-by: default avatarC1rN09 <zhaoqian@pjlab.org.cn>
parent 521f375e
No related branches found
No related tags found
No related merge requests found
......@@ -18,8 +18,8 @@ from addict import Dict
from yapf.yapflib.yapf_api import FormatCode
from mmengine.fileio import dump, load
from mmengine.utils import (check_file_exist, check_install_package,
get_installed_path, import_modules_from_strings)
from mmengine.utils import (check_file_exist, get_installed_path,
import_modules_from_strings, is_installed)
from .utils import (RemoveAssignFromAST, _get_external_cfg_base_path,
_get_external_cfg_path, _get_package_and_cfg_path)
......@@ -594,8 +594,13 @@ class Config:
# Get package name and relative config path.
scope = cfg_path.partition('::')[0]
package, cfg_path = _get_package_and_cfg_path(cfg_path)
if not is_installed(package):
raise ModuleNotFoundError(
f'{package} is not installed, please install {package} '
f'manually')
# Get installed package path.
check_install_package(package)
package_path = get_installed_path(package)
try:
# Get config path from meta file.
......
......@@ -8,7 +8,7 @@ from mmengine.config.utils import (_get_cfg_metainfo,
_get_package_and_cfg_path)
from mmengine.registry import MODELS, DefaultScope
from mmengine.runner import load_checkpoint
from mmengine.utils import check_install_package, get_installed_path
from mmengine.utils import get_installed_path, install_package
def get_config(cfg_path: str, pretrained: bool = False) -> Config:
......@@ -32,8 +32,8 @@ def get_config(cfg_path: str, pretrained: bool = False) -> Config:
""" # noqa E301
# Get package name and relative config path.
package, cfg_path = _get_package_and_cfg_path(cfg_path)
# Check package is installed.
check_install_package(package)
# Install package if it's not installed.
install_package(package)
package_path = get_installed_path(package)
try:
# Use `cfg_path` to search target config file.
......
......@@ -6,8 +6,8 @@ from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
iter_cast, list_cast, requires_executable, requires_package,
slice_list, to_1tuple, to_2tuple, to_3tuple, to_4tuple,
to_ntuple, tuple_cast)
from .package_utils import (call_command, check_install_package,
get_installed_path, is_installed)
from .package_utils import (call_command, get_installed_path, install_package,
is_installed)
from .path import (check_file_exist, fopen, is_abs, is_filepath,
mkdir_or_exist, scandir, symlink)
from .progressbar import (ProgressBar, track_iter_progress,
......@@ -22,9 +22,9 @@ __all__ = [
'is_filepath', 'fopen', 'check_file_exist', 'mkdir_or_exist', 'symlink',
'scandir', 'deprecated_api_warning', 'import_modules_from_strings',
'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple',
'is_installed', 'call_command', 'get_installed_path',
'check_install_package', 'is_abs', 'is_method_overridden', 'has_method',
'digit_version', 'get_git_hash', 'ManagerMeta', 'ManagerMixin', 'Timer',
'check_time', 'TimerError', 'ProgressBar', 'track_iter_progress',
'is_installed', 'call_command', 'get_installed_path', 'install_package',
'is_abs', 'is_method_overridden', 'has_method', 'digit_version',
'get_git_hash', 'ManagerMeta', 'ManagerMixin', 'Timer', 'check_time',
'TimerError', 'ProgressBar', 'track_iter_progress',
'track_parallel_progress', 'track_progress'
]
......@@ -66,6 +66,6 @@ def call_command(cmd: list) -> None:
raise e # type: ignore
def check_install_package(package: str):
def install_package(package: str):
if not is_installed(package):
call_command(['python', '-m', 'pip', 'install', package])
......@@ -435,7 +435,7 @@ class TestConfig:
self._merge_recursive_bases()
self._deprecation()
def test_get_cfg_path(self):
def test_get_cfg_path_local(self):
filename = 'py_config/simple_config.py'
filename = osp.join(self.data_path, 'config', filename)
cfg_name = './base.py'
......@@ -443,13 +443,18 @@ class TestConfig:
assert scope is None
osp.isfile(cfg_path)
# Test scope equal to package name.
@pytest.mark.skipif(
not is_installed('mmdet') or not is_installed('mmcls'),
reason='mmdet and mmcls should be installed')
def test_get_cfg_path_external(self):
filename = 'py_config/simple_config.py'
filename = osp.join(self.data_path, 'config', filename)
cfg_name = 'mmdet::faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
cfg_path, scope = Config._get_cfg_path(cfg_name, filename)
assert scope == 'mmdet'
osp.isfile(cfg_path)
# Test scope does not equal to package name.
cfg_name = 'mmcls::cspnet/cspresnet50_8xb32_in1k.py'
cfg_path, scope = Config._get_cfg_path(cfg_name, filename)
assert scope == 'mmcls'
......@@ -788,6 +793,8 @@ class TestConfig:
assert new_cfg._filename == cfg._filename
assert new_cfg._text == cfg._text
@pytest.mark.skipif(
not is_installed('mmdet'), reason='mmdet should be installed')
def test_get_external_cfg(self):
ext_cfg_path = osp.join(self.data_path,
'config/py_config/test_get_external_cfg.py')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment