更便捷的参数配置方式:“ 命令行参数 + 配置文件 ”

本文在很大程度上参考了 博主@王不对 的文章《如何方便地同时使用命令行参数和配置文件指定程序参数》,在此表示感谢

出发点

在进行深度学习的代码实验时,有很多超参数及配置项需要经常变动,因此要求程序的可配置性一定要高。

我们既希望能在程序执行时灵活的指定每一个参数,又希望大部分参数能有一个默认值、省去了每次运行时的重复指定。

前者,可以通过命令行参数解析的方式进行,在python中一般利用 argparse 来实现;后者,可以通过直接在argparse 中为参数设定默认值或者使用配置文件(一般利用yaml、toml、json等配置文件)来实现。

然而,上述两种方法各自都有一定的缺点。

命令行方式

  1. 命令行参数无法直接表达“层次”

    python main.py --optimizer_name adam --optimizer_lr 1e-4
    

    显然,参数在语义上是有层次的,“平铺”的方式不适合表达大量参数信息

  2. 新增参数时需要通过修改源代码来进行解析指定

    parser.add_argument('--max_epoch', type=int, default=100)
    

    如果我有20个参数要指定,则需要写20行类似的代码!

配置文件方式

  1. 如果要同时进行多个不同参数的实验,则必须写多个配置文件

    即便多次实验中每次只有一个参数不同,也得写一个基本上一样的配置文件!

    虽然每次实验都写一个配置文件可以方便事后仔细分析,但是并不方便!

因此,单一的配置方式并不能很好地满足我们的需求,我们需要更便捷的配置方式。

解决思路

一种很容易想到的方式是 “同时使用命令行参数以及配置文件”,即用命令行参数来动态修改配置文件

另外,关于“命令行参数表达层次”的问题,可以利用 toml 配置格式,其中最重要的特性是“点分隔键”:

下面是两行toml配置字符串:

training.optimizer.name=adam
training.optimizer.lr=0.0001

在 python 中,它可以被很容易地按照层级解析成一个嵌套字典

{’training': {'optimizer': {'name': 'adam', 'lr': 0.0001}}}

这样,我们可以用 argprase 输入这种 toml 点分配置字符串,然后利用解析得到的配置字典去更新由 yaml 配置文件得到的配置字典

举个例子, 命令行输入为

python main.py --toml_cfg "training.optimizer.name='adam'" training.optimizer.lr=0.0001

注意,training.optimizer.name=‘adam’ 外加引号“ 是因为 不加引号时 toml 无法将 ‘adam’ 载入为有效的字符串

而配置文件为

# config.yaml
model:
	name: 'resnet'
	in_ch: 3
	out_ch: 1
training:
	batch_size: 64
	max_epoch: 100
	loss_fn: 'ce'
	optimizer:
		name: 'sgd'
		lr: 0.001

最终,得到的配置字典为

{
    'model': {
        'name': 'resnet',
        'in_ch': 3,
        'out_ch': 1
    },
    'training':{
        'batch_size': 64,
        'max_epoch': 100,
        'loss_fn': 'ce',
        'optimizer': {
            'name': 'adam',
            'lr': 0.0001
        }
    }
}

具体方案

1. 载入相关代码库

import toml
import yaml
import argparse

2. 递归更新配置字典

def deep_update(cfg, new_cfg, __path=''):
    for k, v in new_cfg.items():
        if cfg.get(k, None) is None:
            print(f'Invalid cfg option: {__path[1:]}.{k}')
            continue
        if isinstance(v, dict):
            deep_update(cfg[k], v, '.'.join([__path, str(k)]))
        else:
            cfg[k] = v

3. 载入参数并进行动态更新

def load_cfg():
    # 解析命令行参数
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-c', '--cfg_path', type=str, default='config.yaml')
    parser.add_argument('-nc', '--yaml_cfg', type=str, nargs='*', default=None)
    args = parser.parse_args()
	
    # 载入配置文件
    with open('config.yaml') as fp:
        cfg = yaml.load(fp, Loader=yaml.CLoader)
    
    # 动态更新配置字典
    if args.new_cfg is not None:
        new_cfg = toml.loads('\n'.join(args.new_cfg))
        deep_update(cfg, new_cfg)

    return cfg

4. 主调函数

if __name__ == "__main__":
    cfg = load_cfg()
    print(f'cfg={cfg}')
    # exp_name = get_exp_name(cfg)
    # train_model(cfg, exp_name)
上一篇:torch 保存模型


下一篇:pytorch 断点续训练