From 9087956f70acf1e23e6797dd316336e72cae1911 Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Fri, 6 May 2022 22:59:54 +0800 Subject: [PATCH] [Enhance] Update config with newest mmcv and show custom imports error explicitly (#192) * add import error information * Update config with newest mmcv * add empty line to test config --- mmengine/config/config.py | 27 ++++++++++++++----- mmengine/utils/misc.py | 2 +- .../py_config/test_dump_pickle_support.py | 11 ++++++++ tests/test_config/test_config.py | 15 ++++++++++- 4 files changed, 47 insertions(+), 8 deletions(-) diff --git a/mmengine/config/config.py b/mmengine/config/config.py index d90d7423..6c367488 100644 --- a/mmengine/config/config.py +++ b/mmengine/config/config.py @@ -7,6 +7,7 @@ import platform import shutil import sys import tempfile +import types import uuid import warnings from argparse import Action, ArgumentParser, Namespace @@ -167,7 +168,10 @@ class Config: cfg_dict, cfg_text = Config._file2dict(filename, use_predefined_variables) if import_custom_modules and cfg_dict.get('custom_imports', None): - import_modules_from_strings(**cfg_dict['custom_imports']) + try: + import_modules_from_strings(**cfg_dict['custom_imports']) + except ImportError as e: + raise ImportError('Failed to custom import!') from e return Config(cfg_dict, cfg_text=cfg_text, filename=filename) @staticmethod @@ -396,7 +400,9 @@ class Config: cfg_dict = { name: value for name, value in mod.__dict__.items() - if not name.startswith('__') + if not any((name.startswith('__'), + isinstance(value, types.ModuleType), + isinstance(value, types.FunctionType))) } # delete imported module del sys.modules[temp_module_name] @@ -409,13 +415,13 @@ class Config: if DEPRECATION_KEY in cfg_dict: deprecation_info = cfg_dict.pop(DEPRECATION_KEY) warning_msg = f'The config file {filename} will be deprecated ' \ - 'in the future.' + 'in the future.' if 'expected' in deprecation_info: warning_msg += f' Please use {deprecation_info["expected"]} ' \ - 'instead.' + 'instead.' if 'reference' in deprecation_info: warning_msg += ' More information can be found at ' \ - f'{deprecation_info["reference"]}' + f'{deprecation_info["reference"]}' warnings.warn(warning_msg, DeprecationWarning) cfg_text = filename + '\n' @@ -558,7 +564,7 @@ class Config: def _format_basic_types(k, v, use_mapping=False): if isinstance(v, str): - v_str = f"'{v}'" + v_str = repr(v) else: v_str = str(v) @@ -673,6 +679,13 @@ class Config: return other + def __copy__(self): + cls = self.__class__ + other = cls.__new__(cls) + other.__dict__.update(self.__dict__) + + return other + def __setstate__(self, state: Tuple[dict, Optional[str], Optional[str]]): _cfg_dict, _filename, _text = state super().__setattr__('_cfg_dict', _cfg_dict) @@ -776,6 +789,8 @@ class DictAction(Action): pass if val.lower() in ['true', 'false']: return True if val.lower() == 'true' else False + if val == 'None': + return None return val @staticmethod diff --git a/mmengine/utils/misc.py b/mmengine/utils/misc.py index 4df28968..e99c6400 100644 --- a/mmengine/utils/misc.py +++ b/mmengine/utils/misc.py @@ -87,7 +87,7 @@ def import_modules_from_strings(imports, allow_failed_imports=False): UserWarning) imported_tmp = None else: - raise ImportError + raise ImportError(f'Failed to import {imp}') imported.append(imported_tmp) if single_import: imported = imported[0] diff --git a/tests/data/config/py_config/test_dump_pickle_support.py b/tests/data/config/py_config/test_dump_pickle_support.py index fa7aae26..6050ce10 100644 --- a/tests/data/config/py_config/test_dump_pickle_support.py +++ b/tests/data/config/py_config/test_dump_pickle_support.py @@ -1,4 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. +# config now can have imported modules and defined functions for convenience +import os.path as osp + + +def func(): + return 'string with \tescape\\ characters\n' + + test_item1 = [1, 2] bool_item2 = True str_item3 = 'test' @@ -15,3 +23,6 @@ dict_item4 = dict( f=dict(a='69')) dict_item5 = {'x/x': {'a.0': 233}} dict_list_item6 = {'x/x': [{'a.0': 1., 'b.0': 2.}, {'c/3': 3.}]} +# Test windows path and escape. +str_item_7 = osp.join(osp.expanduser('~'), 'folder') # with backslash in +str_item_8 = func() diff --git a/tests/test_config/test_config.py b/tests/test_config/test_config.py index c4cccb2b..f7b0c82e 100644 --- a/tests/test_config/test_config.py +++ b/tests/test_config/test_config.py @@ -214,7 +214,8 @@ class TestConfig: text_cfg_filename = tmp_path / '_text_config.py' cfg.dump(text_cfg_filename) text_cfg = Config.fromfile(text_cfg_filename) - + assert text_cfg.str_item_7 == osp.join(osp.expanduser('~'), 'folder') + assert text_cfg.str_item_8 == 'string with \tescape\\ characters\n' assert text_cfg._cfg_dict == cfg._cfg_dict cfg_file = osp.join(self.data_path, @@ -670,3 +671,15 @@ class TestConfig: assert new_cfg._cfg_dict is not cfg._cfg_dict assert new_cfg._filename == cfg._filename assert new_cfg._text == cfg._text + + def test_copy(self): + cfg_file = osp.join(self.data_path, 'config', + 'py_config/test_dump_pickle_support.py') + cfg = Config.fromfile(cfg_file) + new_cfg = copy.copy(cfg) + + assert isinstance(new_cfg, Config) + assert new_cfg._cfg_dict == cfg._cfg_dict + assert new_cfg._cfg_dict is cfg._cfg_dict + assert new_cfg._filename == cfg._filename + assert new_cfg._text == cfg._text -- GitLab