From 9087956f70acf1e23e6797dd316336e72cae1911 Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Fri, 6 May 2022 22:59:54 +0800
Subject: [PATCH] [Enhance] Update config with newest mmcv and show custom
 imports error explicitly (#192)

* add import error information

* Update config with newest mmcv

* add empty line to test config
---
 mmengine/config/config.py                     | 27 ++++++++++++++-----
 mmengine/utils/misc.py                        |  2 +-
 .../py_config/test_dump_pickle_support.py     | 11 ++++++++
 tests/test_config/test_config.py              | 15 ++++++++++-
 4 files changed, 47 insertions(+), 8 deletions(-)

diff --git a/mmengine/config/config.py b/mmengine/config/config.py
index d90d7423..6c367488 100644
--- a/mmengine/config/config.py
+++ b/mmengine/config/config.py
@@ -7,6 +7,7 @@ import platform
 import shutil
 import sys
 import tempfile
+import types
 import uuid
 import warnings
 from argparse import Action, ArgumentParser, Namespace
@@ -167,7 +168,10 @@ class Config:
         cfg_dict, cfg_text = Config._file2dict(filename,
                                                use_predefined_variables)
         if import_custom_modules and cfg_dict.get('custom_imports', None):
-            import_modules_from_strings(**cfg_dict['custom_imports'])
+            try:
+                import_modules_from_strings(**cfg_dict['custom_imports'])
+            except ImportError as e:
+                raise ImportError('Failed to custom import!') from e
         return Config(cfg_dict, cfg_text=cfg_text, filename=filename)
 
     @staticmethod
@@ -396,7 +400,9 @@ class Config:
                 cfg_dict = {
                     name: value
                     for name, value in mod.__dict__.items()
-                    if not name.startswith('__')
+                    if not any((name.startswith('__'),
+                                isinstance(value, types.ModuleType),
+                                isinstance(value, types.FunctionType)))
                 }
                 # delete imported module
                 del sys.modules[temp_module_name]
@@ -409,13 +415,13 @@ class Config:
         if DEPRECATION_KEY in cfg_dict:
             deprecation_info = cfg_dict.pop(DEPRECATION_KEY)
             warning_msg = f'The config file {filename} will be deprecated ' \
-                'in the future.'
+                          'in the future.'
             if 'expected' in deprecation_info:
                 warning_msg += f' Please use {deprecation_info["expected"]} ' \
-                    'instead.'
+                               'instead.'
             if 'reference' in deprecation_info:
                 warning_msg += ' More information can be found at ' \
-                    f'{deprecation_info["reference"]}'
+                               f'{deprecation_info["reference"]}'
             warnings.warn(warning_msg, DeprecationWarning)
 
         cfg_text = filename + '\n'
@@ -558,7 +564,7 @@ class Config:
 
         def _format_basic_types(k, v, use_mapping=False):
             if isinstance(v, str):
-                v_str = f"'{v}'"
+                v_str = repr(v)
             else:
                 v_str = str(v)
 
@@ -673,6 +679,13 @@ class Config:
 
         return other
 
+    def __copy__(self):
+        cls = self.__class__
+        other = cls.__new__(cls)
+        other.__dict__.update(self.__dict__)
+
+        return other
+
     def __setstate__(self, state: Tuple[dict, Optional[str], Optional[str]]):
         _cfg_dict, _filename, _text = state
         super().__setattr__('_cfg_dict', _cfg_dict)
@@ -776,6 +789,8 @@ class DictAction(Action):
             pass
         if val.lower() in ['true', 'false']:
             return True if val.lower() == 'true' else False
+        if val == 'None':
+            return None
         return val
 
     @staticmethod
diff --git a/mmengine/utils/misc.py b/mmengine/utils/misc.py
index 4df28968..e99c6400 100644
--- a/mmengine/utils/misc.py
+++ b/mmengine/utils/misc.py
@@ -87,7 +87,7 @@ def import_modules_from_strings(imports, allow_failed_imports=False):
                               UserWarning)
                 imported_tmp = None
             else:
-                raise ImportError
+                raise ImportError(f'Failed to import {imp}')
         imported.append(imported_tmp)
     if single_import:
         imported = imported[0]
diff --git a/tests/data/config/py_config/test_dump_pickle_support.py b/tests/data/config/py_config/test_dump_pickle_support.py
index fa7aae26..6050ce10 100644
--- a/tests/data/config/py_config/test_dump_pickle_support.py
+++ b/tests/data/config/py_config/test_dump_pickle_support.py
@@ -1,4 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+# config now can have imported modules and defined functions for convenience
+import os.path as osp
+
+
+def func():
+    return 'string with \tescape\\ characters\n'
+
+
 test_item1 = [1, 2]
 bool_item2 = True
 str_item3 = 'test'
@@ -15,3 +23,6 @@ dict_item4 = dict(
     f=dict(a='69'))
 dict_item5 = {'x/x': {'a.0': 233}}
 dict_list_item6 = {'x/x': [{'a.0': 1., 'b.0': 2.}, {'c/3': 3.}]}
+# Test windows path and escape.
+str_item_7 = osp.join(osp.expanduser('~'), 'folder') # with backslash in
+str_item_8 = func()
diff --git a/tests/test_config/test_config.py b/tests/test_config/test_config.py
index c4cccb2b..f7b0c82e 100644
--- a/tests/test_config/test_config.py
+++ b/tests/test_config/test_config.py
@@ -214,7 +214,8 @@ class TestConfig:
         text_cfg_filename = tmp_path / '_text_config.py'
         cfg.dump(text_cfg_filename)
         text_cfg = Config.fromfile(text_cfg_filename)
-
+        assert text_cfg.str_item_7 == osp.join(osp.expanduser('~'), 'folder')
+        assert text_cfg.str_item_8 == 'string with \tescape\\ characters\n'
         assert text_cfg._cfg_dict == cfg._cfg_dict
 
         cfg_file = osp.join(self.data_path,
@@ -670,3 +671,15 @@ class TestConfig:
         assert new_cfg._cfg_dict is not cfg._cfg_dict
         assert new_cfg._filename == cfg._filename
         assert new_cfg._text == cfg._text
+
+    def test_copy(self):
+        cfg_file = osp.join(self.data_path, 'config',
+                            'py_config/test_dump_pickle_support.py')
+        cfg = Config.fromfile(cfg_file)
+        new_cfg = copy.copy(cfg)
+
+        assert isinstance(new_cfg, Config)
+        assert new_cfg._cfg_dict == cfg._cfg_dict
+        assert new_cfg._cfg_dict is cfg._cfg_dict
+        assert new_cfg._filename == cfg._filename
+        assert new_cfg._text == cfg._text
-- 
GitLab