diff --git a/mmengine/config/config.py b/mmengine/config/config.py index 6a81797880341be098abc5c9f1bcc06e4db78a43..18560aceccc6a6748327b936ed43f38abf175ae1 100644 --- a/mmengine/config/config.py +++ b/mmengine/config/config.py @@ -550,8 +550,8 @@ class Config: Returns: list: A list of base config. """ - file_format = filename.partition('.')[-1] - if file_format == 'py': + file_format = osp.splitext(filename)[1] + if file_format == '.py': Config._validate_py_syntax(filename) with open(filename, encoding='utf-8') as f: codes = ast.parse(f.read()).body @@ -568,7 +568,7 @@ class Config: base_files = eval(compile(base_code, '', mode='eval')) else: base_files = [] - elif file_format in ('yml', 'yaml', 'json'): + elif file_format in ('.yml', '.yaml', '.json'): import mmengine cfg_dict = mmengine.load(filename) base_files = cfg_dict.get(BASE_KEY, []) diff --git a/tests/test_config/test_config.py b/tests/test_config/test_config.py index 0714734fb127334d5e7dfa54d15dcb64cb4aade3..8420fe7b211673a0adb9a093a2e74cd4c6fb3378 100644 --- a/tests/test_config/test_config.py +++ b/tests/test_config/test_config.py @@ -5,8 +5,10 @@ import os import os.path as osp import platform import sys +import tempfile from importlib import import_module from pathlib import Path +from unittest.mock import patch import pytest @@ -715,6 +717,21 @@ class TestConfig: cfg = Config._file2dict(cfg_file)[0] assert cfg == dict(item1=dict(a=1)) + # Simulate the case that the temporary directory includes `.`, etc. + # /tmp/test.axsgr12/. This patch is to check the issue + # https://github.com/open-mmlab/mmengine/issues/788 has been solved. + class PatchedTempDirectory(tempfile.TemporaryDirectory): + + def __init__(self, *args, prefix='test.', **kwargs): + super().__init__(*args, prefix=prefix, **kwargs) + + with patch('mmengine.config.config.tempfile.TemporaryDirectory', + PatchedTempDirectory): + cfg_file = osp.join(self.data_path, + 'config/py_config/test_py_modify_key.py') + cfg = Config._file2dict(cfg_file)[0] + assert cfg == dict(item1=dict(a=1)) + def _merge_recursive_bases(self): cfg_file = osp.join(self.data_path, 'config/py_config/test_merge_recursive_bases.py')