diff --git a/mmengine/config/config.py b/mmengine/config/config.py index 71f63ab0c59c65ce9c1015f40617fff6df031f9c..8abbf956834379eee7c109ccdee63fafe9400274 100644 --- a/mmengine/config/config.py +++ b/mmengine/config/config.py @@ -5,22 +5,21 @@ import os import os.path as osp import platform import shutil -import sys import tempfile import types import uuid import warnings from argparse import Action, ArgumentParser, Namespace from collections import abc -from importlib import import_module from pathlib import Path -from typing import Any, Optional, Sequence, Tuple, Union +from typing import Any, List, Optional, Sequence, Tuple, Union from addict import Dict from yapf.yapflib.yapf_api import FormatCode from mmengine.fileio import dump, load from mmengine.utils import check_file_exist, import_modules_from_strings +from .utils import RemoveAssignFromAST BASE_KEY = '_base_' DELETE_KEY = '_delete_' @@ -380,7 +379,7 @@ class Config: dir=temp_config_dir, suffix=fileExtname) if platform.system() == 'Windows': temp_config_file.close() - temp_config_name = osp.basename(temp_config_file.name) + # Substitute predefined variables if use_predefined_variables: Config._substitute_predefined_vars(filename, @@ -391,37 +390,47 @@ class Config: base_var_dict = Config._pre_substitute_base_vars( temp_config_file.name, temp_config_file.name) + # Handle base files + base_cfg_dict = ConfigDict() + cfg_text_list = list() + for base_cfg_path in Config._parse_base_files( + temp_config_file.name): + cfg_dir = osp.dirname(filename) + _cfg_dict, _cfg_text = Config._file2dict( + osp.join(cfg_dir, base_cfg_path)) + cfg_text_list.append(_cfg_text) + duplicate_keys = base_cfg_dict.keys() & _cfg_dict.keys() + if len(duplicate_keys) > 0: + raise KeyError('Duplicate key is not allowed among bases. ' + f'Duplicate keys: {duplicate_keys}') + base_cfg_dict.update(_cfg_dict) + if filename.endswith('.py'): - temp_module_name = osp.splitext(temp_config_name)[0] - sys.path.insert(0, temp_config_dir) - Config._validate_py_syntax(filename) - mod = import_module(temp_module_name) - sys.path.pop(0) - cfg_dict = { - name: value - for name, value in mod.__dict__.items() - if not any((name.startswith('__'), - isinstance(value, types.ModuleType), - isinstance(value, types.FunctionType))) - } - # delete imported module - del sys.modules[temp_module_name] + cfg_dict: dict = dict() + with open(temp_config_file.name) as f: + codes = ast.parse(f.read()) + codes = RemoveAssignFromAST(BASE_KEY).visit(codes) + codeobj = compile(codes, '', mode='exec') + eval(codeobj, {'_base_': base_cfg_dict}, cfg_dict) elif filename.endswith(('.yml', '.yaml', '.json')): cfg_dict = load(temp_config_file.name) # close temp file + for key, value in list(cfg_dict.items()): + if isinstance(value, (types.FunctionType, types.ModuleType)): + cfg_dict.pop(key) temp_config_file.close() # check deprecation information if DEPRECATION_KEY in cfg_dict: deprecation_info = cfg_dict.pop(DEPRECATION_KEY) warning_msg = f'The config file {filename} will be deprecated ' \ - 'in the future.' + 'in the future.' if 'expected' in deprecation_info: warning_msg += f' Please use {deprecation_info["expected"]} ' \ - 'instead.' + 'instead.' if 'reference' in deprecation_info: warning_msg += ' More information can be found at ' \ - f'{deprecation_info["reference"]}' + f'{deprecation_info["reference"]}' warnings.warn(warning_msg, DeprecationWarning) cfg_text = filename + '\n' @@ -429,40 +438,59 @@ class Config: # Setting encoding explicitly to resolve coding issue on windows cfg_text += f.read() - if BASE_KEY in cfg_dict: - cfg_dir = osp.dirname(filename) - base_filename = cfg_dict.pop(BASE_KEY) - base_filename = base_filename if isinstance( - base_filename, list) else [base_filename] + # Substitute base variables from strings to their actual values + cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict, + base_cfg_dict) + cfg_dict.pop(BASE_KEY, None) - cfg_dict_list = list() - cfg_text_list = list() - for f in base_filename: - _cfg_dict, _cfg_text = Config._file2dict( - osp.join(cfg_dir, str(f))) - cfg_dict_list.append(_cfg_dict) - cfg_text_list.append(_cfg_text) + cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict) + cfg_dict = { + k: v + for k, v in cfg_dict.items() if not k.startswith('__') + } - base_cfg_dict: dict = dict() - for c in cfg_dict_list: - duplicate_keys = base_cfg_dict.keys() & c.keys() - if len(duplicate_keys) > 0: - raise KeyError('Duplicate key is not allowed among bases. ' - f'Duplicate keys: {duplicate_keys}') - base_cfg_dict.update(c) + # merge cfg_text + cfg_text_list.append(cfg_text) + cfg_text = '\n'.join(cfg_text_list) - # Substitute base variables from strings to their actual values - cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict, - base_cfg_dict) + return cfg_dict, cfg_text - base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict) - cfg_dict = base_cfg_dict + @staticmethod + def _parse_base_files(file_path: str) -> List[str]: + """Get paths of all base config files. - # merge cfg_text - cfg_text_list.append(cfg_text) - cfg_text = '\n'.join(cfg_text_list) + Args: + file_path (str): Path of config. - return cfg_dict, cfg_text + Returns: + List[str]: paths of all base files . + """ + file_format = file_path.partition('.')[-1] + if file_format == 'py': + Config._validate_py_syntax(file_path) + with open(file_path) as f: + codes = ast.parse(f.read()).body + + def is_base_line(c): + return (isinstance(c, ast.Assign) + and c.targets[0].id == BASE_KEY) + + base_code = next((c for c in codes if is_base_line(c)), None) + if base_code is not None: + base_code = ast.Expression( # type: ignore + body=base_code.value) # type: ignore + base_files = eval(compile(base_code, '', mode='eval')) + else: + base_files = [] + elif file_format in ('yml', 'yaml', 'json'): + cfg_dict = load(file_path) + base_files = cfg_dict.get(BASE_KEY, []) + else: + raise TypeError('The config type should be py, json, yaml or ' + f'yml, but got {file_format}') + base_files = base_files if isinstance(base_files, + list) else [base_files] + return base_files @staticmethod def _merge_a_into_b(a: dict, diff --git a/mmengine/config/utils.py b/mmengine/config/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..54c32691f71a2cca051eaa8b65f2cf6aab143657 --- /dev/null +++ b/mmengine/config/utils.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import ast + + +class RemoveAssignFromAST(ast.NodeTransformer): + """Remove Assign node if the target's name match the key. + + Args: + key (str): The target name of the Assign node. + """ + + def __init__(self, key): + self.key = key + + def visit_Assign(self, node): + if (isinstance(node.targets[0], ast.Name) + and node.targets[0].id == self.key): + return None + else: + return node diff --git a/tests/data/config/py_config/test_py_base.py b/tests/data/config/py_config/test_py_base.py new file mode 100644 index 0000000000000000000000000000000000000000..8073705726db5f0a706b21ec62201eb0c8040451 --- /dev/null +++ b/tests/data/config/py_config/test_py_base.py @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. +_base_ = [ + './base1.py', '../yaml_config/base2.yaml', '../json_config/base3.json', + './base4.py' +] +item2 = dict(b=[5, 6]) +item3 = False +item4 = 'test' +_base_.item6[0] = dict(c=0) +item8 = '{{fileBasename}}' +item9, item10, item11 = _base_.item7['b']['c'] diff --git a/tests/data/config/py_config/test_py_nested_path.py b/tests/data/config/py_config/test_py_nested_path.py new file mode 100644 index 0000000000000000000000000000000000000000..b233616bd4879501a48d2c5ffab27e9939f09fa9 --- /dev/null +++ b/tests/data/config/py_config/test_py_nested_path.py @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. +_base_ = ['./test_py_base.py'] +item12 = _base_.item8 +item13 = _base_.item9 +item14 = _base_.item1 +item15 = dict( + a=dict(b=_base_.item2), + b=[_base_.item3], + c=[_base_.item4], + d=[[dict(e=_base_.item5['a'])], _base_.item6], + e=_base_.item1) diff --git a/tests/test_config/test_config.py b/tests/test_config/test_config.py index 5af2a80c4f5238b726dda135349515cd6f549617..6ea0270c559e2e0b9a2948f30dc24a2262a8b2fd 100644 --- a/tests/test_config/test_config.py +++ b/tests/test_config/test_config.py @@ -598,6 +598,59 @@ class TestConfig: }]], e='test_base_variables.py') + cfg_file = osp.join(self.data_path, 'config/py_config/test_py_base.py') + cfg = Config.fromfile(cfg_file) + assert isinstance(cfg, Config) + assert cfg.filename == cfg_file + # cfg.field + assert cfg.item1 == [1, 2] + assert cfg.item2.a == 0 + assert cfg.item2.b == [5, 6] + assert cfg.item3 is False + assert cfg.item4 == 'test' + assert cfg.item5 == dict(a=0, b=1) + assert cfg.item6 == [dict(c=0), dict(b=1)] + assert cfg.item7 == dict(a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3])) + assert cfg.item8 == 'test_py_base.py' + assert cfg.item9 == 3.1 + assert cfg.item10 == 4.2 + assert cfg.item11 == 5.3 + + # test nested base + cfg_file = osp.join(self.data_path, + 'config/py_config/test_py_nested_path.py') + cfg = Config.fromfile(cfg_file) + assert isinstance(cfg, Config) + assert cfg.filename == cfg_file + # cfg.field + assert cfg.item1 == [1, 2] + assert cfg.item2.a == 0 + assert cfg.item2.b == [5, 6] + assert cfg.item3 is False + assert cfg.item4 == 'test' + assert cfg.item5 == dict(a=0, b=1) + assert cfg.item6 == [dict(c=0), dict(b=1)] + assert cfg.item7 == dict(a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3])) + assert cfg.item8 == 'test_py_base.py' + assert cfg.item9 == 3.1 + assert cfg.item10 == 4.2 + assert cfg.item11 == 5.3 + assert cfg.item12 == 'test_py_base.py' + assert cfg.item13 == 3.1 + assert cfg.item14 == [1, 2] + assert cfg.item15 == dict( + a=dict(b=dict(a=0, b=[5, 6])), + b=[False], + c=['test'], + d=[[{ + 'e': 0 + }], [{ + 'c': 0 + }, { + 'b': 1 + }]], + e=[1, 2]) + def _merge_recursive_bases(self): cfg_file = osp.join(self.data_path, 'config/py_config/test_merge_recursive_bases.py') @@ -617,10 +670,10 @@ class TestConfig: assert cfg_dict['item2'] == dict(a=0, b=0) assert cfg_dict['item3'] is True assert cfg_dict['item4'] == 'test' - assert '_delete_' not in cfg_dict['item2'] + assert '_delete_' not in cfg_dict['item1'] assert type(cfg_dict['item1']) == ConfigDict - assert type(cfg_dict['item2']) == dict + assert type(cfg_dict['item2']) == ConfigDict def _merge_intermediate_variable(self):