From f850de71c35b07a5966a25ecd312a54c2320ed4b Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Mon, 1 Aug 2022 20:10:10 +0800
Subject: [PATCH] [Fix] Support use 'global variable' in config function (#390)

* Support use 'global var' in config function

* upload test file
---
 mmengine/config/config.py                            | 12 ++++++++++--
 .../config/py_config/test_py_function_global_var.py  |  9 +++++++++
 tests/test_config/test_config.py                     |  7 +++++++
 3 files changed, 26 insertions(+), 2 deletions(-)
 create mode 100644 tests/data/config/py_config/test_py_function_global_var.py

diff --git a/mmengine/config/config.py b/mmengine/config/config.py
index 8abbf956..05012fe8 100644
--- a/mmengine/config/config.py
+++ b/mmengine/config/config.py
@@ -406,12 +406,20 @@ class Config:
                 base_cfg_dict.update(_cfg_dict)
 
             if filename.endswith('.py'):
-                cfg_dict: dict = dict()
                 with open(temp_config_file.name) as f:
                     codes = ast.parse(f.read())
                     codes = RemoveAssignFromAST(BASE_KEY).visit(codes)
                 codeobj = compile(codes, '', mode='exec')
-                eval(codeobj, {'_base_': base_cfg_dict}, cfg_dict)
+                # Support load global variable in nested function of the
+                # config.
+                global_locals_var = {'_base_': base_cfg_dict}
+                ori_keys = set(global_locals_var.keys())
+                eval(codeobj, global_locals_var, global_locals_var)
+                cfg_dict = {
+                    key: value
+                    for key, value in global_locals_var.items()
+                    if (key not in ori_keys and not key.startswith('__'))
+                }
             elif filename.endswith(('.yml', '.yaml', '.json')):
                 cfg_dict = load(temp_config_file.name)
             # close temp file
diff --git a/tests/data/config/py_config/test_py_function_global_var.py b/tests/data/config/py_config/test_py_function_global_var.py
new file mode 100644
index 00000000..8a5c0953
--- /dev/null
+++ b/tests/data/config/py_config/test_py_function_global_var.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+item1 = 1
+
+
+def get_item2():
+    return item1 + 1
+
+
+item2 = get_item2()
diff --git a/tests/test_config/test_config.py b/tests/test_config/test_config.py
index 6ea0270c..aed3e817 100644
--- a/tests/test_config/test_config.py
+++ b/tests/test_config/test_config.py
@@ -651,6 +651,13 @@ class TestConfig:
             }]],
             e=[1, 2])
 
+        # Test use global variable in config function
+        cfg_file = osp.join(self.data_path,
+                            'config/py_config/test_py_function_global_var.py')
+        cfg = Config._file2dict(cfg_file)[0]
+        assert cfg['item1'] == 1
+        assert cfg['item2'] == 2
+
     def _merge_recursive_bases(self):
         cfg_file = osp.join(self.data_path,
                             'config/py_config/test_merge_recursive_bases.py')
-- 
GitLab