From ad1b43faf2a9185e93ad37250aa11cfe5b6be33d Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Fri, 30 Dec 2022 14:56:14 +0800
Subject: [PATCH] [Fix] Fix `Config` cannot parse base config when there is `.`
 in tmp path (#856)

* [Fix] Fix config cannot parse tmp path like

* Add comments

* Add comments

* Apply suggestions from code review

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
---
 mmengine/config/config.py        |  6 +++---
 tests/test_config/test_config.py | 17 +++++++++++++++++
 2 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/mmengine/config/config.py b/mmengine/config/config.py
index 6a817978..18560ace 100644
--- a/mmengine/config/config.py
+++ b/mmengine/config/config.py
@@ -550,8 +550,8 @@ class Config:
         Returns:
             list: A list of base config.
         """
-        file_format = filename.partition('.')[-1]
-        if file_format == 'py':
+        file_format = osp.splitext(filename)[1]
+        if file_format == '.py':
             Config._validate_py_syntax(filename)
             with open(filename, encoding='utf-8') as f:
                 codes = ast.parse(f.read()).body
@@ -568,7 +568,7 @@ class Config:
                     base_files = eval(compile(base_code, '', mode='eval'))
                 else:
                     base_files = []
-        elif file_format in ('yml', 'yaml', 'json'):
+        elif file_format in ('.yml', '.yaml', '.json'):
             import mmengine
             cfg_dict = mmengine.load(filename)
             base_files = cfg_dict.get(BASE_KEY, [])
diff --git a/tests/test_config/test_config.py b/tests/test_config/test_config.py
index 0714734f..8420fe7b 100644
--- a/tests/test_config/test_config.py
+++ b/tests/test_config/test_config.py
@@ -5,8 +5,10 @@ import os
 import os.path as osp
 import platform
 import sys
+import tempfile
 from importlib import import_module
 from pathlib import Path
+from unittest.mock import patch
 
 import pytest
 
@@ -715,6 +717,21 @@ class TestConfig:
         cfg = Config._file2dict(cfg_file)[0]
         assert cfg == dict(item1=dict(a=1))
 
+        # Simulate the case that the temporary directory includes `.`, etc.
+        # /tmp/test.axsgr12/. This patch is to check the issue
+        # https://github.com/open-mmlab/mmengine/issues/788 has been solved.
+        class PatchedTempDirectory(tempfile.TemporaryDirectory):
+
+            def __init__(self, *args, prefix='test.', **kwargs):
+                super().__init__(*args, prefix=prefix, **kwargs)
+
+        with patch('mmengine.config.config.tempfile.TemporaryDirectory',
+                   PatchedTempDirectory):
+            cfg_file = osp.join(self.data_path,
+                                'config/py_config/test_py_modify_key.py')
+            cfg = Config._file2dict(cfg_file)[0]
+            assert cfg == dict(item1=dict(a=1))
+
     def _merge_recursive_bases(self):
         cfg_file = osp.join(self.data_path,
                             'config/py_config/test_merge_recursive_bases.py')
-- 
GitLab