pipeline
1)尝试读取少量图片,检查图像格式,并且进行可视化,对数据做个大概了解;
2)划分数据集(训练集、验证集、测试集),并且分别统计下各项指标:图像分辨率、均值、方差等,可以了解到它们之间的数据分布是否接近:
224,32 imagenet
3)实现数据读取1尝试单张与批次读取,以测试是否有bug:
4)搭建模型:
5)实现损失函数;
6)选择优化算法、学习率更新策略(虽然大部分框架已有很方便的接口,但实际任务时仍有较大可能需要自定义实现
7)编写训练pipepline:调用模型的前向过程计算loss,基于框架的接口计算梯度(有可能在其中需要对梯度进行处理,如梯度截断等),调用优化器更新核型参数,清空前二批次的梯度累积等:-
8)训练送代到一定次数时在验证集上进行验证,需要实现验证过程的pipeline与训练过程类你,但不需要进行反向传播更新模型参数,只需基于评估指标
(如精度进行评估,而评估指标可以调用API或者自定义实现;
9)训练完毕后保存模型参数,实现推断过程,通常还包括可视化结果.
良好的编程规范
1.把参数放到配置文件中,保持功能函数不要改动
2.
数据集
数据集简介:
102类花朵数据集总共102种类别的花朵组成,每个类别数量40至258不等,如图所示 下载地址: https://www.robots.ox.ac.uk/~vgg/data/flowers/102/
102flowers.tgz:纯图片,无标签
imagelabels.mat:图片对应的标签
To do List 将数据集重新组织,以类别为文件夹进行划分 划分训练集、验证集和测试集