diff --git a/mmengine/config/config.py b/mmengine/config/config.py index 8abbf956834379eee7c109ccdee63fafe9400274..05012fe82cfe09850ea1ca951a7831e02d2505ff 100644 --- a/mmengine/config/config.py +++ b/mmengine/config/config.py @@ -406,12 +406,20 @@ class Config: base_cfg_dict.update(_cfg_dict) if filename.endswith('.py'): - cfg_dict: dict = dict() with open(temp_config_file.name) as f: codes = ast.parse(f.read()) codes = RemoveAssignFromAST(BASE_KEY).visit(codes) codeobj = compile(codes, '', mode='exec') - eval(codeobj, {'_base_': base_cfg_dict}, cfg_dict) + # Support load global variable in nested function of the + # config. + global_locals_var = {'_base_': base_cfg_dict} + ori_keys = set(global_locals_var.keys()) + eval(codeobj, global_locals_var, global_locals_var) + cfg_dict = { + key: value + for key, value in global_locals_var.items() + if (key not in ori_keys and not key.startswith('__')) + } elif filename.endswith(('.yml', '.yaml', '.json')): cfg_dict = load(temp_config_file.name) # close temp file diff --git a/tests/data/config/py_config/test_py_function_global_var.py b/tests/data/config/py_config/test_py_function_global_var.py new file mode 100644 index 0000000000000000000000000000000000000000..8a5c0953cc9fca81b518591e746c4e655d008736 --- /dev/null +++ b/tests/data/config/py_config/test_py_function_global_var.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +item1 = 1 + + +def get_item2(): + return item1 + 1 + + +item2 = get_item2() diff --git a/tests/test_config/test_config.py b/tests/test_config/test_config.py index 6ea0270c559e2e0b9a2948f30dc24a2262a8b2fd..aed3e81772d16e546bc70e0de0b23ff7a8604846 100644 --- a/tests/test_config/test_config.py +++ b/tests/test_config/test_config.py @@ -651,6 +651,13 @@ class TestConfig: }]], e=[1, 2]) + # Test use global variable in config function + cfg_file = osp.join(self.data_path, + 'config/py_config/test_py_function_global_var.py') + cfg = Config._file2dict(cfg_file)[0] + assert cfg['item1'] == 1 + assert cfg['item2'] == 2 + def _merge_recursive_bases(self): cfg_file = osp.join(self.data_path, 'config/py_config/test_merge_recursive_bases.py')