From ad1b43faf2a9185e93ad37250aa11cfe5b6be33d Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Fri, 30 Dec 2022 14:56:14 +0800 Subject: [PATCH] [Fix] Fix `Config` cannot parse base config when there is `.` in tmp path (#856) * [Fix] Fix config cannot parse tmp path like * Add comments * Add comments * Apply suggestions from code review Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> --- mmengine/config/config.py | 6 +++--- tests/test_config/test_config.py | 17 +++++++++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/mmengine/config/config.py b/mmengine/config/config.py index 6a817978..18560ace 100644 --- a/mmengine/config/config.py +++ b/mmengine/config/config.py @@ -550,8 +550,8 @@ class Config: Returns: list: A list of base config. """ - file_format = filename.partition('.')[-1] - if file_format == 'py': + file_format = osp.splitext(filename)[1] + if file_format == '.py': Config._validate_py_syntax(filename) with open(filename, encoding='utf-8') as f: codes = ast.parse(f.read()).body @@ -568,7 +568,7 @@ class Config: base_files = eval(compile(base_code, '', mode='eval')) else: base_files = [] - elif file_format in ('yml', 'yaml', 'json'): + elif file_format in ('.yml', '.yaml', '.json'): import mmengine cfg_dict = mmengine.load(filename) base_files = cfg_dict.get(BASE_KEY, []) diff --git a/tests/test_config/test_config.py b/tests/test_config/test_config.py index 0714734f..8420fe7b 100644 --- a/tests/test_config/test_config.py +++ b/tests/test_config/test_config.py @@ -5,8 +5,10 @@ import os import os.path as osp import platform import sys +import tempfile from importlib import import_module from pathlib import Path +from unittest.mock import patch import pytest @@ -715,6 +717,21 @@ class TestConfig: cfg = Config._file2dict(cfg_file)[0] assert cfg == dict(item1=dict(a=1)) + # Simulate the case that the temporary directory includes `.`, etc. + # /tmp/test.axsgr12/. This patch is to check the issue + # https://github.com/open-mmlab/mmengine/issues/788 has been solved. + class PatchedTempDirectory(tempfile.TemporaryDirectory): + + def __init__(self, *args, prefix='test.', **kwargs): + super().__init__(*args, prefix=prefix, **kwargs) + + with patch('mmengine.config.config.tempfile.TemporaryDirectory', + PatchedTempDirectory): + cfg_file = osp.join(self.data_path, + 'config/py_config/test_py_modify_key.py') + cfg = Config._file2dict(cfg_file)[0] + assert cfg == dict(item1=dict(a=1)) + def _merge_recursive_bases(self): cfg_file = osp.join(self.data_path, 'config/py_config/test_merge_recursive_bases.py') -- GitLab