From 45001a1f6f0eb3cabdb6ef7ef773c92e839b0e13 Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Thu, 14 Jul 2022 13:05:55 +0800
Subject: [PATCH] [Enhance] Support using variables in base config directly as
 normal variables. (#329)

* first commit

* Support modify base config and add unit test

* remove import mmengine in config

* add unit test

* fix lint

* add unit test

* move RemoveAssignFromAST to config utils

* git add utils

* fix format issue in test file

* refine unit test

* refine unit test
---
 mmengine/config/config.py                     | 126 +++++++++++-------
 mmengine/config/utils.py                      |  20 +++
 tests/data/config/py_config/test_py_base.py   |  11 ++
 .../config/py_config/test_py_nested_path.py   |  11 ++
 tests/test_config/test_config.py              |  57 +++++++-
 5 files changed, 174 insertions(+), 51 deletions(-)
 create mode 100644 mmengine/config/utils.py
 create mode 100644 tests/data/config/py_config/test_py_base.py
 create mode 100644 tests/data/config/py_config/test_py_nested_path.py

diff --git a/mmengine/config/config.py b/mmengine/config/config.py
index 71f63ab0..8abbf956 100644
--- a/mmengine/config/config.py
+++ b/mmengine/config/config.py
@@ -5,22 +5,21 @@ import os
 import os.path as osp
 import platform
 import shutil
-import sys
 import tempfile
 import types
 import uuid
 import warnings
 from argparse import Action, ArgumentParser, Namespace
 from collections import abc
-from importlib import import_module
 from pathlib import Path
-from typing import Any, Optional, Sequence, Tuple, Union
+from typing import Any, List, Optional, Sequence, Tuple, Union
 
 from addict import Dict
 from yapf.yapflib.yapf_api import FormatCode
 
 from mmengine.fileio import dump, load
 from mmengine.utils import check_file_exist, import_modules_from_strings
+from .utils import RemoveAssignFromAST
 
 BASE_KEY = '_base_'
 DELETE_KEY = '_delete_'
@@ -380,7 +379,7 @@ class Config:
                 dir=temp_config_dir, suffix=fileExtname)
             if platform.system() == 'Windows':
                 temp_config_file.close()
-            temp_config_name = osp.basename(temp_config_file.name)
+
             # Substitute predefined variables
             if use_predefined_variables:
                 Config._substitute_predefined_vars(filename,
@@ -391,37 +390,47 @@ class Config:
             base_var_dict = Config._pre_substitute_base_vars(
                 temp_config_file.name, temp_config_file.name)
 
+            # Handle base files
+            base_cfg_dict = ConfigDict()
+            cfg_text_list = list()
+            for base_cfg_path in Config._parse_base_files(
+                    temp_config_file.name):
+                cfg_dir = osp.dirname(filename)
+                _cfg_dict, _cfg_text = Config._file2dict(
+                    osp.join(cfg_dir, base_cfg_path))
+                cfg_text_list.append(_cfg_text)
+                duplicate_keys = base_cfg_dict.keys() & _cfg_dict.keys()
+                if len(duplicate_keys) > 0:
+                    raise KeyError('Duplicate key is not allowed among bases. '
+                                   f'Duplicate keys: {duplicate_keys}')
+                base_cfg_dict.update(_cfg_dict)
+
             if filename.endswith('.py'):
-                temp_module_name = osp.splitext(temp_config_name)[0]
-                sys.path.insert(0, temp_config_dir)
-                Config._validate_py_syntax(filename)
-                mod = import_module(temp_module_name)
-                sys.path.pop(0)
-                cfg_dict = {
-                    name: value
-                    for name, value in mod.__dict__.items()
-                    if not any((name.startswith('__'),
-                                isinstance(value, types.ModuleType),
-                                isinstance(value, types.FunctionType)))
-                }
-                # delete imported module
-                del sys.modules[temp_module_name]
+                cfg_dict: dict = dict()
+                with open(temp_config_file.name) as f:
+                    codes = ast.parse(f.read())
+                    codes = RemoveAssignFromAST(BASE_KEY).visit(codes)
+                codeobj = compile(codes, '', mode='exec')
+                eval(codeobj, {'_base_': base_cfg_dict}, cfg_dict)
             elif filename.endswith(('.yml', '.yaml', '.json')):
                 cfg_dict = load(temp_config_file.name)
             # close temp file
+            for key, value in list(cfg_dict.items()):
+                if isinstance(value, (types.FunctionType, types.ModuleType)):
+                    cfg_dict.pop(key)
             temp_config_file.close()
 
         # check deprecation information
         if DEPRECATION_KEY in cfg_dict:
             deprecation_info = cfg_dict.pop(DEPRECATION_KEY)
             warning_msg = f'The config file {filename} will be deprecated ' \
-                          'in the future.'
+                'in the future.'
             if 'expected' in deprecation_info:
                 warning_msg += f' Please use {deprecation_info["expected"]} ' \
-                               'instead.'
+                    'instead.'
             if 'reference' in deprecation_info:
                 warning_msg += ' More information can be found at ' \
-                               f'{deprecation_info["reference"]}'
+                    f'{deprecation_info["reference"]}'
             warnings.warn(warning_msg, DeprecationWarning)
 
         cfg_text = filename + '\n'
@@ -429,40 +438,59 @@ class Config:
             # Setting encoding explicitly to resolve coding issue on windows
             cfg_text += f.read()
 
-        if BASE_KEY in cfg_dict:
-            cfg_dir = osp.dirname(filename)
-            base_filename = cfg_dict.pop(BASE_KEY)
-            base_filename = base_filename if isinstance(
-                base_filename, list) else [base_filename]
+        # Substitute base variables from strings to their actual values
+        cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict,
+                                                base_cfg_dict)
+        cfg_dict.pop(BASE_KEY, None)
 
-            cfg_dict_list = list()
-            cfg_text_list = list()
-            for f in base_filename:
-                _cfg_dict, _cfg_text = Config._file2dict(
-                    osp.join(cfg_dir, str(f)))
-                cfg_dict_list.append(_cfg_dict)
-                cfg_text_list.append(_cfg_text)
+        cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
+        cfg_dict = {
+            k: v
+            for k, v in cfg_dict.items() if not k.startswith('__')
+        }
 
-            base_cfg_dict: dict = dict()
-            for c in cfg_dict_list:
-                duplicate_keys = base_cfg_dict.keys() & c.keys()
-                if len(duplicate_keys) > 0:
-                    raise KeyError('Duplicate key is not allowed among bases. '
-                                   f'Duplicate keys: {duplicate_keys}')
-                base_cfg_dict.update(c)
+        # merge cfg_text
+        cfg_text_list.append(cfg_text)
+        cfg_text = '\n'.join(cfg_text_list)
 
-            # Substitute base variables from strings to their actual values
-            cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict,
-                                                    base_cfg_dict)
+        return cfg_dict, cfg_text
 
-            base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
-            cfg_dict = base_cfg_dict
+    @staticmethod
+    def _parse_base_files(file_path: str) -> List[str]:
+        """Get paths of all base config files.
 
-            # merge cfg_text
-            cfg_text_list.append(cfg_text)
-            cfg_text = '\n'.join(cfg_text_list)
+        Args:
+            file_path (str): Path of config.
 
-        return cfg_dict, cfg_text
+        Returns:
+            List[str]: paths of all base files .
+        """
+        file_format = file_path.partition('.')[-1]
+        if file_format == 'py':
+            Config._validate_py_syntax(file_path)
+            with open(file_path) as f:
+                codes = ast.parse(f.read()).body
+
+                def is_base_line(c):
+                    return (isinstance(c, ast.Assign)
+                            and c.targets[0].id == BASE_KEY)
+
+                base_code = next((c for c in codes if is_base_line(c)), None)
+                if base_code is not None:
+                    base_code = ast.Expression(  # type: ignore
+                        body=base_code.value)  # type: ignore
+                    base_files = eval(compile(base_code, '', mode='eval'))
+                else:
+                    base_files = []
+        elif file_format in ('yml', 'yaml', 'json'):
+            cfg_dict = load(file_path)
+            base_files = cfg_dict.get(BASE_KEY, [])
+        else:
+            raise TypeError('The config type should be py, json, yaml or '
+                            f'yml, but got {file_format}')
+        base_files = base_files if isinstance(base_files,
+                                              list) else [base_files]
+        return base_files
 
     @staticmethod
     def _merge_a_into_b(a: dict,
diff --git a/mmengine/config/utils.py b/mmengine/config/utils.py
new file mode 100644
index 00000000..54c32691
--- /dev/null
+++ b/mmengine/config/utils.py
@@ -0,0 +1,20 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import ast
+
+
+class RemoveAssignFromAST(ast.NodeTransformer):
+    """Remove Assign node if the target's name match the key.
+
+    Args:
+        key (str): The target name of the Assign node.
+    """
+
+    def __init__(self, key):
+        self.key = key
+
+    def visit_Assign(self, node):
+        if (isinstance(node.targets[0], ast.Name)
+                and node.targets[0].id == self.key):
+            return None
+        else:
+            return node
diff --git a/tests/data/config/py_config/test_py_base.py b/tests/data/config/py_config/test_py_base.py
new file mode 100644
index 00000000..80737057
--- /dev/null
+++ b/tests/data/config/py_config/test_py_base.py
@@ -0,0 +1,11 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+_base_ = [
+    './base1.py', '../yaml_config/base2.yaml', '../json_config/base3.json',
+    './base4.py'
+]
+item2 = dict(b=[5, 6])
+item3 = False
+item4 = 'test'
+_base_.item6[0] = dict(c=0)
+item8 = '{{fileBasename}}'
+item9, item10, item11 = _base_.item7['b']['c']
diff --git a/tests/data/config/py_config/test_py_nested_path.py b/tests/data/config/py_config/test_py_nested_path.py
new file mode 100644
index 00000000..b233616b
--- /dev/null
+++ b/tests/data/config/py_config/test_py_nested_path.py
@@ -0,0 +1,11 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+_base_ = ['./test_py_base.py']
+item12 = _base_.item8
+item13 = _base_.item9
+item14 = _base_.item1
+item15 = dict(
+    a=dict(b=_base_.item2),
+    b=[_base_.item3],
+    c=[_base_.item4],
+    d=[[dict(e=_base_.item5['a'])], _base_.item6],
+    e=_base_.item1)
diff --git a/tests/test_config/test_config.py b/tests/test_config/test_config.py
index 5af2a80c..6ea0270c 100644
--- a/tests/test_config/test_config.py
+++ b/tests/test_config/test_config.py
@@ -598,6 +598,59 @@ class TestConfig:
             }]],
             e='test_base_variables.py')
 
+        cfg_file = osp.join(self.data_path, 'config/py_config/test_py_base.py')
+        cfg = Config.fromfile(cfg_file)
+        assert isinstance(cfg, Config)
+        assert cfg.filename == cfg_file
+        # cfg.field
+        assert cfg.item1 == [1, 2]
+        assert cfg.item2.a == 0
+        assert cfg.item2.b == [5, 6]
+        assert cfg.item3 is False
+        assert cfg.item4 == 'test'
+        assert cfg.item5 == dict(a=0, b=1)
+        assert cfg.item6 == [dict(c=0), dict(b=1)]
+        assert cfg.item7 == dict(a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3]))
+        assert cfg.item8 == 'test_py_base.py'
+        assert cfg.item9 == 3.1
+        assert cfg.item10 == 4.2
+        assert cfg.item11 == 5.3
+
+        # test nested base
+        cfg_file = osp.join(self.data_path,
+                            'config/py_config/test_py_nested_path.py')
+        cfg = Config.fromfile(cfg_file)
+        assert isinstance(cfg, Config)
+        assert cfg.filename == cfg_file
+        # cfg.field
+        assert cfg.item1 == [1, 2]
+        assert cfg.item2.a == 0
+        assert cfg.item2.b == [5, 6]
+        assert cfg.item3 is False
+        assert cfg.item4 == 'test'
+        assert cfg.item5 == dict(a=0, b=1)
+        assert cfg.item6 == [dict(c=0), dict(b=1)]
+        assert cfg.item7 == dict(a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3]))
+        assert cfg.item8 == 'test_py_base.py'
+        assert cfg.item9 == 3.1
+        assert cfg.item10 == 4.2
+        assert cfg.item11 == 5.3
+        assert cfg.item12 == 'test_py_base.py'
+        assert cfg.item13 == 3.1
+        assert cfg.item14 == [1, 2]
+        assert cfg.item15 == dict(
+            a=dict(b=dict(a=0, b=[5, 6])),
+            b=[False],
+            c=['test'],
+            d=[[{
+                'e': 0
+            }], [{
+                'c': 0
+            }, {
+                'b': 1
+            }]],
+            e=[1, 2])
+
     def _merge_recursive_bases(self):
         cfg_file = osp.join(self.data_path,
                             'config/py_config/test_merge_recursive_bases.py')
@@ -617,10 +670,10 @@ class TestConfig:
         assert cfg_dict['item2'] == dict(a=0, b=0)
         assert cfg_dict['item3'] is True
         assert cfg_dict['item4'] == 'test'
-        assert '_delete_' not in cfg_dict['item2']
+        assert '_delete_' not in cfg_dict['item1']
 
         assert type(cfg_dict['item1']) == ConfigDict
-        assert type(cfg_dict['item2']) == dict
+        assert type(cfg_dict['item2']) == ConfigDict
 
     def _merge_intermediate_variable(self):
 
-- 
GitLab