Skip to content
Snippets Groups Projects
Unverified Commit ad1b43fa authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

[Fix] Fix `Config` cannot parse base config when there is `.` in tmp path (#856)


* [Fix] Fix config cannot parse tmp path like

* Add comments

* Add comments

* Apply suggestions from code review

Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent 6af88783
No related branches found
No related tags found
No related merge requests found
......@@ -550,8 +550,8 @@ class Config:
Returns:
list: A list of base config.
"""
file_format = filename.partition('.')[-1]
if file_format == 'py':
file_format = osp.splitext(filename)[1]
if file_format == '.py':
Config._validate_py_syntax(filename)
with open(filename, encoding='utf-8') as f:
codes = ast.parse(f.read()).body
......@@ -568,7 +568,7 @@ class Config:
base_files = eval(compile(base_code, '', mode='eval'))
else:
base_files = []
elif file_format in ('yml', 'yaml', 'json'):
elif file_format in ('.yml', '.yaml', '.json'):
import mmengine
cfg_dict = mmengine.load(filename)
base_files = cfg_dict.get(BASE_KEY, [])
......
......@@ -5,8 +5,10 @@ import os
import os.path as osp
import platform
import sys
import tempfile
from importlib import import_module
from pathlib import Path
from unittest.mock import patch
import pytest
......@@ -715,6 +717,21 @@ class TestConfig:
cfg = Config._file2dict(cfg_file)[0]
assert cfg == dict(item1=dict(a=1))
# Simulate the case that the temporary directory includes `.`, etc.
# /tmp/test.axsgr12/. This patch is to check the issue
# https://github.com/open-mmlab/mmengine/issues/788 has been solved.
class PatchedTempDirectory(tempfile.TemporaryDirectory):
def __init__(self, *args, prefix='test.', **kwargs):
super().__init__(*args, prefix=prefix, **kwargs)
with patch('mmengine.config.config.tempfile.TemporaryDirectory',
PatchedTempDirectory):
cfg_file = osp.join(self.data_path,
'config/py_config/test_py_modify_key.py')
cfg = Config._file2dict(cfg_file)[0]
assert cfg == dict(item1=dict(a=1))
def _merge_recursive_bases(self):
cfg_file = osp.join(self.data_path,
'config/py_config/test_merge_recursive_bases.py')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment