本文在很大程度上参考了 博主@王不对 的文章《如何方便地同时使用命令行参数和配置文件指定程序参数》,在此表示感谢
出发点
在进行深度学习的代码实验时,有很多超参数及配置项需要经常变动,因此要求程序的可配置性一定要高。
我们既希望能在程序执行时灵活的指定每一个参数,又希望大部分参数能有一个默认值、省去了每次运行时的重复指定。
前者,可以通过命令行参数解析的方式进行,在python中一般利用 argparse 来实现;后者,可以通过直接在argparse 中为参数设定默认值或者使用配置文件(一般利用yaml、toml、json等配置文件)来实现。
然而,上述两种方法各自都有一定的缺点。
命令行方式
-
命令行参数无法直接表达“层次”
python main.py --optimizer_name adam --optimizer_lr 1e-4
显然,参数在语义上是有层次的,“平铺”的方式不适合表达大量参数信息
-
新增参数时需要通过修改源代码来进行解析指定
parser.add_argument('--max_epoch', type=int, default=100)
如果我有20个参数要指定,则需要写20行类似的代码!
配置文件方式
-
如果要同时进行多个不同参数的实验,则必须写多个配置文件
即便多次实验中每次只有一个参数不同,也得写一个基本上一样的配置文件!
虽然每次实验都写一个配置文件可以方便事后仔细分析,但是并不方便!
因此,单一的配置方式并不能很好地满足我们的需求,我们需要更便捷的配置方式。
解决思路
一种很容易想到的方式是 “同时使用命令行参数以及配置文件”,即用命令行参数来动态修改配置文件
另外,关于“命令行参数表达层次”的问题,可以利用 toml 配置格式,其中最重要的特性是“点分隔键”:
下面是两行toml配置字符串:
training.optimizer.name=adam
training.optimizer.lr=0.0001
在 python 中,它可以被很容易地按照层级解析成一个嵌套字典
{’training': {'optimizer': {'name': 'adam', 'lr': 0.0001}}}
这样,我们可以用 argprase 输入这种 toml 点分配置字符串,然后利用解析得到的配置字典去更新由 yaml 配置文件得到的配置字典
举个例子, 命令行输入为
python main.py --toml_cfg "training.optimizer.name='adam'" training.optimizer.lr=0.0001
注意,training.optimizer.name=‘adam’ 外加引号“ 是因为 不加引号时 toml 无法将 ‘adam’ 载入为有效的字符串
而配置文件为
# config.yaml
model:
name: 'resnet'
in_ch: 3
out_ch: 1
training:
batch_size: 64
max_epoch: 100
loss_fn: 'ce'
optimizer:
name: 'sgd'
lr: 0.001
最终,得到的配置字典为
{
'model': {
'name': 'resnet',
'in_ch': 3,
'out_ch': 1
},
'training':{
'batch_size': 64,
'max_epoch': 100,
'loss_fn': 'ce',
'optimizer': {
'name': 'adam',
'lr': 0.0001
}
}
}
具体方案
1. 载入相关代码库
import toml
import yaml
import argparse
2. 递归更新配置字典
def deep_update(cfg, new_cfg, __path=''):
for k, v in new_cfg.items():
if cfg.get(k, None) is None:
print(f'Invalid cfg option: {__path[1:]}.{k}')
continue
if isinstance(v, dict):
deep_update(cfg[k], v, '.'.join([__path, str(k)]))
else:
cfg[k] = v
3. 载入参数并进行动态更新
def load_cfg():
# 解析命令行参数
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-c', '--cfg_path', type=str, default='config.yaml')
parser.add_argument('-nc', '--yaml_cfg', type=str, nargs='*', default=None)
args = parser.parse_args()
# 载入配置文件
with open('config.yaml') as fp:
cfg = yaml.load(fp, Loader=yaml.CLoader)
# 动态更新配置字典
if args.new_cfg is not None:
new_cfg = toml.loads('\n'.join(args.new_cfg))
deep_update(cfg, new_cfg)
return cfg
4. 主调函数
if __name__ == "__main__":
cfg = load_cfg()
print(f'cfg={cfg}')
# exp_name = get_exp_name(cfg)
# train_model(cfg, exp_name)