Skip to content
Snippets Groups Projects
Unverified Commit 9087956f authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

[Enhance] Update config with newest mmcv and show custom imports error explicitly (#192)

* add import error information

* Update config with newest mmcv

* add empty line to test config
parent 859f4d15
No related branches found
No related tags found
No related merge requests found
...@@ -7,6 +7,7 @@ import platform ...@@ -7,6 +7,7 @@ import platform
import shutil import shutil
import sys import sys
import tempfile import tempfile
import types
import uuid import uuid
import warnings import warnings
from argparse import Action, ArgumentParser, Namespace from argparse import Action, ArgumentParser, Namespace
...@@ -167,7 +168,10 @@ class Config: ...@@ -167,7 +168,10 @@ class Config:
cfg_dict, cfg_text = Config._file2dict(filename, cfg_dict, cfg_text = Config._file2dict(filename,
use_predefined_variables) use_predefined_variables)
if import_custom_modules and cfg_dict.get('custom_imports', None): if import_custom_modules and cfg_dict.get('custom_imports', None):
import_modules_from_strings(**cfg_dict['custom_imports']) try:
import_modules_from_strings(**cfg_dict['custom_imports'])
except ImportError as e:
raise ImportError('Failed to custom import!') from e
return Config(cfg_dict, cfg_text=cfg_text, filename=filename) return Config(cfg_dict, cfg_text=cfg_text, filename=filename)
@staticmethod @staticmethod
...@@ -396,7 +400,9 @@ class Config: ...@@ -396,7 +400,9 @@ class Config:
cfg_dict = { cfg_dict = {
name: value name: value
for name, value in mod.__dict__.items() for name, value in mod.__dict__.items()
if not name.startswith('__') if not any((name.startswith('__'),
isinstance(value, types.ModuleType),
isinstance(value, types.FunctionType)))
} }
# delete imported module # delete imported module
del sys.modules[temp_module_name] del sys.modules[temp_module_name]
...@@ -409,13 +415,13 @@ class Config: ...@@ -409,13 +415,13 @@ class Config:
if DEPRECATION_KEY in cfg_dict: if DEPRECATION_KEY in cfg_dict:
deprecation_info = cfg_dict.pop(DEPRECATION_KEY) deprecation_info = cfg_dict.pop(DEPRECATION_KEY)
warning_msg = f'The config file {filename} will be deprecated ' \ warning_msg = f'The config file {filename} will be deprecated ' \
'in the future.' 'in the future.'
if 'expected' in deprecation_info: if 'expected' in deprecation_info:
warning_msg += f' Please use {deprecation_info["expected"]} ' \ warning_msg += f' Please use {deprecation_info["expected"]} ' \
'instead.' 'instead.'
if 'reference' in deprecation_info: if 'reference' in deprecation_info:
warning_msg += ' More information can be found at ' \ warning_msg += ' More information can be found at ' \
f'{deprecation_info["reference"]}' f'{deprecation_info["reference"]}'
warnings.warn(warning_msg, DeprecationWarning) warnings.warn(warning_msg, DeprecationWarning)
cfg_text = filename + '\n' cfg_text = filename + '\n'
...@@ -558,7 +564,7 @@ class Config: ...@@ -558,7 +564,7 @@ class Config:
def _format_basic_types(k, v, use_mapping=False): def _format_basic_types(k, v, use_mapping=False):
if isinstance(v, str): if isinstance(v, str):
v_str = f"'{v}'" v_str = repr(v)
else: else:
v_str = str(v) v_str = str(v)
...@@ -673,6 +679,13 @@ class Config: ...@@ -673,6 +679,13 @@ class Config:
return other return other
def __copy__(self):
cls = self.__class__
other = cls.__new__(cls)
other.__dict__.update(self.__dict__)
return other
def __setstate__(self, state: Tuple[dict, Optional[str], Optional[str]]): def __setstate__(self, state: Tuple[dict, Optional[str], Optional[str]]):
_cfg_dict, _filename, _text = state _cfg_dict, _filename, _text = state
super().__setattr__('_cfg_dict', _cfg_dict) super().__setattr__('_cfg_dict', _cfg_dict)
...@@ -776,6 +789,8 @@ class DictAction(Action): ...@@ -776,6 +789,8 @@ class DictAction(Action):
pass pass
if val.lower() in ['true', 'false']: if val.lower() in ['true', 'false']:
return True if val.lower() == 'true' else False return True if val.lower() == 'true' else False
if val == 'None':
return None
return val return val
@staticmethod @staticmethod
......
...@@ -87,7 +87,7 @@ def import_modules_from_strings(imports, allow_failed_imports=False): ...@@ -87,7 +87,7 @@ def import_modules_from_strings(imports, allow_failed_imports=False):
UserWarning) UserWarning)
imported_tmp = None imported_tmp = None
else: else:
raise ImportError raise ImportError(f'Failed to import {imp}')
imported.append(imported_tmp) imported.append(imported_tmp)
if single_import: if single_import:
imported = imported[0] imported = imported[0]
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
# config now can have imported modules and defined functions for convenience
import os.path as osp
def func():
return 'string with \tescape\\ characters\n'
test_item1 = [1, 2] test_item1 = [1, 2]
bool_item2 = True bool_item2 = True
str_item3 = 'test' str_item3 = 'test'
...@@ -15,3 +23,6 @@ dict_item4 = dict( ...@@ -15,3 +23,6 @@ dict_item4 = dict(
f=dict(a='69')) f=dict(a='69'))
dict_item5 = {'x/x': {'a.0': 233}} dict_item5 = {'x/x': {'a.0': 233}}
dict_list_item6 = {'x/x': [{'a.0': 1., 'b.0': 2.}, {'c/3': 3.}]} dict_list_item6 = {'x/x': [{'a.0': 1., 'b.0': 2.}, {'c/3': 3.}]}
# Test windows path and escape.
str_item_7 = osp.join(osp.expanduser('~'), 'folder') # with backslash in
str_item_8 = func()
...@@ -214,7 +214,8 @@ class TestConfig: ...@@ -214,7 +214,8 @@ class TestConfig:
text_cfg_filename = tmp_path / '_text_config.py' text_cfg_filename = tmp_path / '_text_config.py'
cfg.dump(text_cfg_filename) cfg.dump(text_cfg_filename)
text_cfg = Config.fromfile(text_cfg_filename) text_cfg = Config.fromfile(text_cfg_filename)
assert text_cfg.str_item_7 == osp.join(osp.expanduser('~'), 'folder')
assert text_cfg.str_item_8 == 'string with \tescape\\ characters\n'
assert text_cfg._cfg_dict == cfg._cfg_dict assert text_cfg._cfg_dict == cfg._cfg_dict
cfg_file = osp.join(self.data_path, cfg_file = osp.join(self.data_path,
...@@ -670,3 +671,15 @@ class TestConfig: ...@@ -670,3 +671,15 @@ class TestConfig:
assert new_cfg._cfg_dict is not cfg._cfg_dict assert new_cfg._cfg_dict is not cfg._cfg_dict
assert new_cfg._filename == cfg._filename assert new_cfg._filename == cfg._filename
assert new_cfg._text == cfg._text assert new_cfg._text == cfg._text
def test_copy(self):
cfg_file = osp.join(self.data_path, 'config',
'py_config/test_dump_pickle_support.py')
cfg = Config.fromfile(cfg_file)
new_cfg = copy.copy(cfg)
assert isinstance(new_cfg, Config)
assert new_cfg._cfg_dict == cfg._cfg_dict
assert new_cfg._cfg_dict is cfg._cfg_dict
assert new_cfg._filename == cfg._filename
assert new_cfg._text == cfg._text
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment