Skip to content
Snippets Groups Projects
Unverified Commit 452b3656 authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

[Fix] Fix dump config without self.filename (#202)

* fix config

* add docstring and unit test

* update tutorial

* update tutorial

* fix markdown format

* fix markdown format
parent 4dcbd269
No related branches found
No related tags found
No related merge requests found
......@@ -194,6 +194,71 @@ a = {{_base_.model}}
# 等价于 a = dict(type='ResNet', depth=50)
```
## 配置文件的导出
在启动训练脚本时,用户可能通过传参的方式来修改配置文件的部分字段,为此我们提供了 `dump`
接口来导出更改后的配置文件。与读取配置文件类似,用户可以通过 `cfg.dump('config.xxx')` 来选择导出文件的格式。`dump`
同样可以导出有继承关系的配置文件,导出的文件可以被独立使用,不再依赖于 `_base_` 中定义的文件。
基于继承一节定义的 `resnet50.py`
```python
_base_ = ['optimizer_cfg.py', 'runtime_cfg.py']
model = dict(type='ResNet', depth=50)
```
我们将其加载后导出:
```python
cfg = Config.fromfile('resnet50.py')
cfg.dump('resnet50_dump.py')
```
`dumped_resnet50.py`
```python
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
gpu_ids = [0, 1]
model = dict(type='ResNet', depth=50)
```
类似的,我们可以导出 json、yaml 格式的配置文件
`dumped_resnet50.yaml`
```yaml
gpu_ids:
- 0
- 1
model:
depth: 50
type: ResNet
optimizer:
lr: 0.02
momentum: 0.9
type: SGD
weight_decay: 0.0001
```
`dumped_resnet50.json`
```json
{"optimizer": {"type": "SGD", "lr": 0.02, "momentum": 0.9, "weight_decay": 0.0001}, "gpu_ids": [0, 1], "model": {"type": "ResNet", "depth": 50}}
```
此外,`dump` 不仅能导出加载自文件的 `cfg`,还能导出加载自字典的 `cfg`
```python
cfg = Config(dict(a=1, b=2))
cfg.dump('demo.py')
```
`demo.py`
```python
a=1
b=2
```
## 其他进阶用法
这里介绍一下配置类的进阶用法,这些小技巧可能使用户开发和使用算法库更简单方便。
......
......@@ -12,6 +12,7 @@ import warnings
from argparse import Action, ArgumentParser, Namespace
from collections import abc
from importlib import import_module
from pathlib import Path
from typing import Any, Optional, Sequence, Tuple, Union
from addict import Dict
......@@ -99,7 +100,8 @@ class Config:
Args:
cfg_dict (dict, optional): A config dictionary. Defaults to None.
cfg_text (str, optional): Text of config. Defaults to None.
filename (str, optional): Name of config file. Defaults to None.
filename (str or Path, optional): Name of config file.
Defaults to None.
Examples:
>>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
......@@ -123,7 +125,8 @@ class Config:
def __init__(self,
cfg_dict: dict = None,
cfg_text: Optional[str] = None,
filename: str = None):
filename: Optional[Union[str, Path]] = None):
filename = str(filename) if isinstance(filename, Path) else filename
if cfg_dict is None:
cfg_dict = dict()
elif not isinstance(cfg_dict, dict):
......@@ -145,13 +148,13 @@ class Config:
super().__setattr__('_text', text)
@staticmethod
def fromfile(filename: str,
def fromfile(filename: Union[str, Path],
use_predefined_variables: bool = True,
import_custom_modules: bool = True) -> 'Config':
"""Build a Config instance from config file.
Args:
filename (str): Name of config file.
filename (str or Path): Name of config file.
use_predefined_variables (bool, optional): Whether to use
predefined variables. Defaults to True.
import_custom_modules (bool, optional): Whether to support
......@@ -160,6 +163,7 @@ class Config:
Returns:
Config: Config instance built from config file.
"""
filename = str(filename) if isinstance(filename, Path) else filename
cfg_dict, cfg_text = Config._file2dict(filename,
use_predefined_variables)
if import_custom_modules and cfg_dict.get('custom_imports', None):
......@@ -675,32 +679,31 @@ class Config:
super().__setattr__('_filename', _filename)
super().__setattr__('_text', _text)
def dump(self, file: Optional[str] = None):
def dump(self, file: Optional[Union[str, Path]] = None):
"""Dump config to file or return config text.
Args:
file (str, optional): If not specified, then the object
file (str or Path, optional): If not specified, then the object
is dumped to a str, otherwise to a file specified by the filename.
Defaults to None.
Returns:
str or None: Config text.
"""
cfg_dict = super().__getattribute__('_cfg_dict').to_dict()
if self.filename.endswith('.py'):
if file is None:
file = str(file) if isinstance(file, Path) else file
cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict()
if file is None:
if self.filename is None or self.filename.endswith('.py'):
return self.pretty_text
else:
with open(file, 'w', encoding='utf-8') as f:
f.write(self.pretty_text)
return None
else:
if file is None:
file_format = self.filename.split('.')[-1]
return dump(cfg_dict, file_format=file_format)
else:
dump(cfg_dict, file)
return None
elif file.endswith('.py'):
with open(file, 'w', encoding='utf-8') as f:
f.write(self.pretty_text)
else:
file_format = file.split('.')[-1]
return dump(cfg_dict, file=file, file_format=file_format)
def merge_from_dict(self,
options: dict,
......
......@@ -224,8 +224,26 @@ class TestConfig:
pkl_cfg_filename = tmp_path / '_pickle.pkl'
dump(cfg, pkl_cfg_filename)
pkl_cfg = load(pkl_cfg_filename)
assert pkl_cfg._cfg_dict == cfg._cfg_dict
# Test dump config from dict.
cfg_dict = dict(a=1, b=2)
cfg = Config(cfg_dict)
assert cfg.pretty_text == cfg.dump()
# Test dump python format config.
dump_file = tmp_path / 'dump_from_dict.py'
cfg.dump(dump_file)
with open(dump_file, 'r') as f:
assert f.read() == 'a = 1\nb = 2\n'
# Test dump json format config.
dump_file = tmp_path / 'dump_from_dict.json'
cfg.dump(dump_file)
with open(dump_file, 'r') as f:
assert f.read() == '{"a": 1, "b": 2}'
# Test dump yaml format config.
dump_file = tmp_path / 'dump_from_dict.yaml'
cfg.dump(dump_file)
with open(dump_file, 'r') as f:
assert f.read() == 'a: 1\nb: 2\n'
def test_pretty_text(self, tmp_path):
cfg_file = osp.join(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment