Skip to content
Snippets Groups Projects
Unverified Commit f850de71 authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

[Fix] Support use 'global variable' in config function (#390)

* Support use 'global var' in config function

* upload test file
parent 81c3de54
No related branches found
No related tags found
No related merge requests found
...@@ -406,12 +406,20 @@ class Config: ...@@ -406,12 +406,20 @@ class Config:
base_cfg_dict.update(_cfg_dict) base_cfg_dict.update(_cfg_dict)
if filename.endswith('.py'): if filename.endswith('.py'):
cfg_dict: dict = dict()
with open(temp_config_file.name) as f: with open(temp_config_file.name) as f:
codes = ast.parse(f.read()) codes = ast.parse(f.read())
codes = RemoveAssignFromAST(BASE_KEY).visit(codes) codes = RemoveAssignFromAST(BASE_KEY).visit(codes)
codeobj = compile(codes, '', mode='exec') codeobj = compile(codes, '', mode='exec')
eval(codeobj, {'_base_': base_cfg_dict}, cfg_dict) # Support load global variable in nested function of the
# config.
global_locals_var = {'_base_': base_cfg_dict}
ori_keys = set(global_locals_var.keys())
eval(codeobj, global_locals_var, global_locals_var)
cfg_dict = {
key: value
for key, value in global_locals_var.items()
if (key not in ori_keys and not key.startswith('__'))
}
elif filename.endswith(('.yml', '.yaml', '.json')): elif filename.endswith(('.yml', '.yaml', '.json')):
cfg_dict = load(temp_config_file.name) cfg_dict = load(temp_config_file.name)
# close temp file # close temp file
......
# Copyright (c) OpenMMLab. All rights reserved.
item1 = 1
def get_item2():
return item1 + 1
item2 = get_item2()
...@@ -651,6 +651,13 @@ class TestConfig: ...@@ -651,6 +651,13 @@ class TestConfig:
}]], }]],
e=[1, 2]) e=[1, 2])
# Test use global variable in config function
cfg_file = osp.join(self.data_path,
'config/py_config/test_py_function_global_var.py')
cfg = Config._file2dict(cfg_file)[0]
assert cfg['item1'] == 1
assert cfg['item2'] == 2
def _merge_recursive_bases(self): def _merge_recursive_bases(self):
cfg_file = osp.join(self.data_path, cfg_file = osp.join(self.data_path,
'config/py_config/test_merge_recursive_bases.py') 'config/py_config/test_merge_recursive_bases.py')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment