配置文件
本项目使用json格式的配置文件,这种格式可以直接用json包解析使用比较方便。
在config文件夹下创建config.yml文件,这个文件包含了一些基础配置和超参的设置,因为从头写项目这些参数还不确定可以先复制一份完整的过来,后续字形修改。
GPUID: 0
WORKERS: 1
PRINT_FREQ: 10
SAVE_FREQ: 10
PIN_MEMORY: False
OUTPUT_DIR: 'output'
CUDNN:
BENCHMARK: True
DETERMINISTIC: False
ENABLED: True
DATASET:
DATASET: 360CC
ROOT: "/home/xmy/pytorch_code/dataset/"
CHAR_FILE: '/home/xmy/pytorch_code/dataset/french_img/french_dict.txt'
JSON_FILE: {'train': 'french_img/french_training_mlt_2017.txt', 'val': 'lib/dataset/txt/test.txt'}
SCALE_FACTOR: 0.25
ROT_FACTOR: 30
STD: 0.193
MEAN: 0.588
ALPHABETS: ''
TRAIN:
BATCH_SIZE_PER_GPU: 32
SHUFFLE: True
BEGIN_EPOCH: 0
END_EPOCH: 100
RESUME:
IS_RESUME: False
FILE: ''
OPTIMIZER: 'adam'
LR: 0.0001
WD: 0.0
LR_STEP: [60, 80]
LR_FACTOR: 0.1
MOMENTUM: 0.0
NESTEROV: False
RMSPROP_ALPHA:
RMSPROP_CENTERED:
FINETUNE:
IS_FINETUNE: true
FINETUNE_CHECKPOINIT: 'output/checkpoints/mixed_second_finetune_acc_97P7.pth'
FREEZE: true
TEST:
BATCH_SIZE_PER_GPU: 16
SHUFFLE: True # for random test rather than test on the whole validation set
NUM_TEST_BATCH: 1000
NUM_TEST_DISP: 10
MODEL:
NAME: 'crnn'
IMAGE_SIZE:
OW: 280 # origial width: 280
H: 32
W: 160 # resized width: 160
NUM_CLASSES: 0
NUM_HIDDEN: 256
这种文件解析完比如需要获取MODEL项下面的NUM_CLASSES项的值时很简单,只需要读取config.MODEL.NUM_CLASSES
的值便可。
编写解析函数
解析函数属于功能型函数,因此讲改代码放在utils下的utils.py文件内:
import yaml
from easydict import EasyDict as edict
def load_yml(filePath):
assert(os.path.isfile(filePath))
# 载入配置
with open(filePath, 'r') as f:
config = yaml.load(f)
config = edict(config)
char_file = config.DATASET.CHAR_FILE
# 载入字典,计算类别
with open(char_file, 'r',encoding="utf-8") as file:
config.DICT = {char.strip():num for num, char in enumerate(file.readlines())}
config.MODEL.NUM_CLASSES = len(config.DICT)
return config
在train.py里面测试改代码:
import torch.utils.data as data
import numpy as np
import time
import cv2
import sys
import os
sys.path.append(os.getcwd())
import argparse
import model.model as crnn
import torch
import torch.optim as optim
from utils.utils import load_yml
def parse_arg():
parser = argparse.ArgumentParser(description="train crnn")
parser.add_argument('--cfg', help='experiment configuration filename', required=True, type=str)
args = parser.parse_args()
config = load_yml(args.cfg)
return config
if __name__ == "__main__":
config = parse_arg()
print(config)
- 重点注意这句import:from utils.utils import load_yml,其他的无关紧要
- 使用–cfg参数来读取配置文件路径
- 记得修改这些参数
执行命令:
python3 train.py --cfg config/config.yml
输出如下:
可以看到配置文件解析完成~,接下来加载数据。