From f850de71c35b07a5966a25ecd312a54c2320ed4b Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Mon, 1 Aug 2022 20:10:10 +0800 Subject: [PATCH] [Fix] Support use 'global variable' in config function (#390) * Support use 'global var' in config function * upload test file --- mmengine/config/config.py | 12 ++++++++++-- .../config/py_config/test_py_function_global_var.py | 9 +++++++++ tests/test_config/test_config.py | 7 +++++++ 3 files changed, 26 insertions(+), 2 deletions(-) create mode 100644 tests/data/config/py_config/test_py_function_global_var.py diff --git a/mmengine/config/config.py b/mmengine/config/config.py index 8abbf956..05012fe8 100644 --- a/mmengine/config/config.py +++ b/mmengine/config/config.py @@ -406,12 +406,20 @@ class Config: base_cfg_dict.update(_cfg_dict) if filename.endswith('.py'): - cfg_dict: dict = dict() with open(temp_config_file.name) as f: codes = ast.parse(f.read()) codes = RemoveAssignFromAST(BASE_KEY).visit(codes) codeobj = compile(codes, '', mode='exec') - eval(codeobj, {'_base_': base_cfg_dict}, cfg_dict) + # Support load global variable in nested function of the + # config. + global_locals_var = {'_base_': base_cfg_dict} + ori_keys = set(global_locals_var.keys()) + eval(codeobj, global_locals_var, global_locals_var) + cfg_dict = { + key: value + for key, value in global_locals_var.items() + if (key not in ori_keys and not key.startswith('__')) + } elif filename.endswith(('.yml', '.yaml', '.json')): cfg_dict = load(temp_config_file.name) # close temp file diff --git a/tests/data/config/py_config/test_py_function_global_var.py b/tests/data/config/py_config/test_py_function_global_var.py new file mode 100644 index 00000000..8a5c0953 --- /dev/null +++ b/tests/data/config/py_config/test_py_function_global_var.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +item1 = 1 + + +def get_item2(): + return item1 + 1 + + +item2 = get_item2() diff --git a/tests/test_config/test_config.py b/tests/test_config/test_config.py index 6ea0270c..aed3e817 100644 --- a/tests/test_config/test_config.py +++ b/tests/test_config/test_config.py @@ -651,6 +651,13 @@ class TestConfig: }]], e=[1, 2]) + # Test use global variable in config function + cfg_file = osp.join(self.data_path, + 'config/py_config/test_py_function_global_var.py') + cfg = Config._file2dict(cfg_file)[0] + assert cfg['item1'] == 1 + assert cfg['item2'] == 2 + def _merge_recursive_bases(self): cfg_file = osp.join(self.data_path, 'config/py_config/test_merge_recursive_bases.py') -- GitLab