Skip to content
Snippets Groups Projects
Unverified Commit 45001a1f authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

[Enhance] Support using variables in base config directly as normal variables. (#329)

* first commit

* Support modify base config and add unit test

* remove import mmengine in config

* add unit test

* fix lint

* add unit test

* move RemoveAssignFromAST to config utils

* git add utils

* fix format issue in test file

* refine unit test

* refine unit test
parent 6b608b4e
No related branches found
No related tags found
No related merge requests found
......@@ -5,22 +5,21 @@ import os
import os.path as osp
import platform
import shutil
import sys
import tempfile
import types
import uuid
import warnings
from argparse import Action, ArgumentParser, Namespace
from collections import abc
from importlib import import_module
from pathlib import Path
from typing import Any, Optional, Sequence, Tuple, Union
from typing import Any, List, Optional, Sequence, Tuple, Union
from addict import Dict
from yapf.yapflib.yapf_api import FormatCode
from mmengine.fileio import dump, load
from mmengine.utils import check_file_exist, import_modules_from_strings
from .utils import RemoveAssignFromAST
BASE_KEY = '_base_'
DELETE_KEY = '_delete_'
......@@ -380,7 +379,7 @@ class Config:
dir=temp_config_dir, suffix=fileExtname)
if platform.system() == 'Windows':
temp_config_file.close()
temp_config_name = osp.basename(temp_config_file.name)
# Substitute predefined variables
if use_predefined_variables:
Config._substitute_predefined_vars(filename,
......@@ -391,37 +390,47 @@ class Config:
base_var_dict = Config._pre_substitute_base_vars(
temp_config_file.name, temp_config_file.name)
# Handle base files
base_cfg_dict = ConfigDict()
cfg_text_list = list()
for base_cfg_path in Config._parse_base_files(
temp_config_file.name):
cfg_dir = osp.dirname(filename)
_cfg_dict, _cfg_text = Config._file2dict(
osp.join(cfg_dir, base_cfg_path))
cfg_text_list.append(_cfg_text)
duplicate_keys = base_cfg_dict.keys() & _cfg_dict.keys()
if len(duplicate_keys) > 0:
raise KeyError('Duplicate key is not allowed among bases. '
f'Duplicate keys: {duplicate_keys}')
base_cfg_dict.update(_cfg_dict)
if filename.endswith('.py'):
temp_module_name = osp.splitext(temp_config_name)[0]
sys.path.insert(0, temp_config_dir)
Config._validate_py_syntax(filename)
mod = import_module(temp_module_name)
sys.path.pop(0)
cfg_dict = {
name: value
for name, value in mod.__dict__.items()
if not any((name.startswith('__'),
isinstance(value, types.ModuleType),
isinstance(value, types.FunctionType)))
}
# delete imported module
del sys.modules[temp_module_name]
cfg_dict: dict = dict()
with open(temp_config_file.name) as f:
codes = ast.parse(f.read())
codes = RemoveAssignFromAST(BASE_KEY).visit(codes)
codeobj = compile(codes, '', mode='exec')
eval(codeobj, {'_base_': base_cfg_dict}, cfg_dict)
elif filename.endswith(('.yml', '.yaml', '.json')):
cfg_dict = load(temp_config_file.name)
# close temp file
for key, value in list(cfg_dict.items()):
if isinstance(value, (types.FunctionType, types.ModuleType)):
cfg_dict.pop(key)
temp_config_file.close()
# check deprecation information
if DEPRECATION_KEY in cfg_dict:
deprecation_info = cfg_dict.pop(DEPRECATION_KEY)
warning_msg = f'The config file {filename} will be deprecated ' \
'in the future.'
'in the future.'
if 'expected' in deprecation_info:
warning_msg += f' Please use {deprecation_info["expected"]} ' \
'instead.'
'instead.'
if 'reference' in deprecation_info:
warning_msg += ' More information can be found at ' \
f'{deprecation_info["reference"]}'
f'{deprecation_info["reference"]}'
warnings.warn(warning_msg, DeprecationWarning)
cfg_text = filename + '\n'
......@@ -429,40 +438,59 @@ class Config:
# Setting encoding explicitly to resolve coding issue on windows
cfg_text += f.read()
if BASE_KEY in cfg_dict:
cfg_dir = osp.dirname(filename)
base_filename = cfg_dict.pop(BASE_KEY)
base_filename = base_filename if isinstance(
base_filename, list) else [base_filename]
# Substitute base variables from strings to their actual values
cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict,
base_cfg_dict)
cfg_dict.pop(BASE_KEY, None)
cfg_dict_list = list()
cfg_text_list = list()
for f in base_filename:
_cfg_dict, _cfg_text = Config._file2dict(
osp.join(cfg_dir, str(f)))
cfg_dict_list.append(_cfg_dict)
cfg_text_list.append(_cfg_text)
cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
cfg_dict = {
k: v
for k, v in cfg_dict.items() if not k.startswith('__')
}
base_cfg_dict: dict = dict()
for c in cfg_dict_list:
duplicate_keys = base_cfg_dict.keys() & c.keys()
if len(duplicate_keys) > 0:
raise KeyError('Duplicate key is not allowed among bases. '
f'Duplicate keys: {duplicate_keys}')
base_cfg_dict.update(c)
# merge cfg_text
cfg_text_list.append(cfg_text)
cfg_text = '\n'.join(cfg_text_list)
# Substitute base variables from strings to their actual values
cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict,
base_cfg_dict)
return cfg_dict, cfg_text
base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
cfg_dict = base_cfg_dict
@staticmethod
def _parse_base_files(file_path: str) -> List[str]:
"""Get paths of all base config files.
# merge cfg_text
cfg_text_list.append(cfg_text)
cfg_text = '\n'.join(cfg_text_list)
Args:
file_path (str): Path of config.
return cfg_dict, cfg_text
Returns:
List[str]: paths of all base files .
"""
file_format = file_path.partition('.')[-1]
if file_format == 'py':
Config._validate_py_syntax(file_path)
with open(file_path) as f:
codes = ast.parse(f.read()).body
def is_base_line(c):
return (isinstance(c, ast.Assign)
and c.targets[0].id == BASE_KEY)
base_code = next((c for c in codes if is_base_line(c)), None)
if base_code is not None:
base_code = ast.Expression( # type: ignore
body=base_code.value) # type: ignore
base_files = eval(compile(base_code, '', mode='eval'))
else:
base_files = []
elif file_format in ('yml', 'yaml', 'json'):
cfg_dict = load(file_path)
base_files = cfg_dict.get(BASE_KEY, [])
else:
raise TypeError('The config type should be py, json, yaml or '
f'yml, but got {file_format}')
base_files = base_files if isinstance(base_files,
list) else [base_files]
return base_files
@staticmethod
def _merge_a_into_b(a: dict,
......
# Copyright (c) OpenMMLab. All rights reserved.
import ast
class RemoveAssignFromAST(ast.NodeTransformer):
"""Remove Assign node if the target's name match the key.
Args:
key (str): The target name of the Assign node.
"""
def __init__(self, key):
self.key = key
def visit_Assign(self, node):
if (isinstance(node.targets[0], ast.Name)
and node.targets[0].id == self.key):
return None
else:
return node
# Copyright (c) OpenMMLab. All rights reserved.
_base_ = [
'./base1.py', '../yaml_config/base2.yaml', '../json_config/base3.json',
'./base4.py'
]
item2 = dict(b=[5, 6])
item3 = False
item4 = 'test'
_base_.item6[0] = dict(c=0)
item8 = '{{fileBasename}}'
item9, item10, item11 = _base_.item7['b']['c']
# Copyright (c) OpenMMLab. All rights reserved.
_base_ = ['./test_py_base.py']
item12 = _base_.item8
item13 = _base_.item9
item14 = _base_.item1
item15 = dict(
a=dict(b=_base_.item2),
b=[_base_.item3],
c=[_base_.item4],
d=[[dict(e=_base_.item5['a'])], _base_.item6],
e=_base_.item1)
......@@ -598,6 +598,59 @@ class TestConfig:
}]],
e='test_base_variables.py')
cfg_file = osp.join(self.data_path, 'config/py_config/test_py_base.py')
cfg = Config.fromfile(cfg_file)
assert isinstance(cfg, Config)
assert cfg.filename == cfg_file
# cfg.field
assert cfg.item1 == [1, 2]
assert cfg.item2.a == 0
assert cfg.item2.b == [5, 6]
assert cfg.item3 is False
assert cfg.item4 == 'test'
assert cfg.item5 == dict(a=0, b=1)
assert cfg.item6 == [dict(c=0), dict(b=1)]
assert cfg.item7 == dict(a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3]))
assert cfg.item8 == 'test_py_base.py'
assert cfg.item9 == 3.1
assert cfg.item10 == 4.2
assert cfg.item11 == 5.3
# test nested base
cfg_file = osp.join(self.data_path,
'config/py_config/test_py_nested_path.py')
cfg = Config.fromfile(cfg_file)
assert isinstance(cfg, Config)
assert cfg.filename == cfg_file
# cfg.field
assert cfg.item1 == [1, 2]
assert cfg.item2.a == 0
assert cfg.item2.b == [5, 6]
assert cfg.item3 is False
assert cfg.item4 == 'test'
assert cfg.item5 == dict(a=0, b=1)
assert cfg.item6 == [dict(c=0), dict(b=1)]
assert cfg.item7 == dict(a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3]))
assert cfg.item8 == 'test_py_base.py'
assert cfg.item9 == 3.1
assert cfg.item10 == 4.2
assert cfg.item11 == 5.3
assert cfg.item12 == 'test_py_base.py'
assert cfg.item13 == 3.1
assert cfg.item14 == [1, 2]
assert cfg.item15 == dict(
a=dict(b=dict(a=0, b=[5, 6])),
b=[False],
c=['test'],
d=[[{
'e': 0
}], [{
'c': 0
}, {
'b': 1
}]],
e=[1, 2])
def _merge_recursive_bases(self):
cfg_file = osp.join(self.data_path,
'config/py_config/test_merge_recursive_bases.py')
......@@ -617,10 +670,10 @@ class TestConfig:
assert cfg_dict['item2'] == dict(a=0, b=0)
assert cfg_dict['item3'] is True
assert cfg_dict['item4'] == 'test'
assert '_delete_' not in cfg_dict['item2']
assert '_delete_' not in cfg_dict['item1']
assert type(cfg_dict['item1']) == ConfigDict
assert type(cfg_dict['item2']) == dict
assert type(cfg_dict['item2']) == ConfigDict
def _merge_intermediate_variable(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment