做这个事情是为了拉通流程,所以对数据集进行了重新制作,只选取了部分类别来完成目标分类任务。
代码地址:
完整数据集:https://pan.baidu.com/wap/init?surl=6f1nQchS-zBtzSWn9Guyyg 密码:iksk
重新制作后的:
一.项目环境
python3.6 torch1.6 torchvision==0.7.0
二.数据集介绍
首先说下数据集:AI Challenger 2018农作物病害检测竞赛就是由上海新客科技为竞赛提供农作物叶子图像的数据集:标注图片5万张,包含10种植物(苹果、樱桃、葡萄、柑桔、桃、草莓、番茄、辣椒、玉米、马铃薯)的27种病害,合计61个分类(按“物种-病害-程度”分)
标签类别对照表:见https://github.com/xungeer29/AI-Challenger-Plant-Disease-Recognition/blob/master/README.md
标注文件为json文件,是一个列表,列表中每张图片的信息以字典形式保存。key值:“disease_class”, “image_id”。
三.数据集制作
首先参考这篇进行的数据分布的分析,选择了9-16玉米的病害数据来完成本次任务。
with open(train_data_json) as datafile1:
trainDataFram=pd.read_json(datafile1,orient='records')
with open(val_data_json) as datafile2: #first check if it's a valid json file or not
validateDataFram =pd.read_json(datafile2,orient='records')
total=trainDataFram.isnull().sum().sort_values(ascending=False)
percent=(trainDataFram.isnull().sum())/(trainDataFram.isnull().count()).sort_values(ascending = False)
missing_validation_data = pd.concat([total, percent], axis=1, keys=['Total', 'Percent'],sort=False)
# print(missing_validation_data.head())
dataDistribute=trainDataFram.groupby(by=['disease_class']).size()
# print(dataDistribute)
plt.figure(figsize=(50,20),dpi=100)
plt.xticks(range(len(dataDistribute)),dataDistribute.index.tolist(),fontsize=40)
plt.yticks(fontsize=40)
bar=plt.bar(dataDistribute.index.tolist(), dataDistribute.tolist(),width=0.7)
for b in bar:
h=b.get_height()
plt.text(b.get_x()+b.get_width()/2,h,int(h),ha='center',fontsize=30)
plt.show()
数据集制作使用的pandas进行数据筛选,然后生成train和val文件夹,里面包含标签9-16命名的图片文件夹,然后将对应类别的图片导入文件夹,比较简单详见git。
四.训练
训练脚本来自:https://github.com/pytorch/examples/tree/master/imagenet
根据自己的情况进行修改:
1.开头argparse添加数据集路径
parser.add_argument('-data', default='/opt/yyl/data/plant2018/', metavar='DIR',help='path to dataset')
2.简便起见,resnet50写入default
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',choices=model_names, help='model architecture: ' +' | '.join(model_names) +' (default: resnet18)')
3.使用预训练权重
parser.add_argument('--pretrained', default='True',dest='pretrained', action='store_true', help='use pre-trained model')
4.创建模型后面修改全连接层输出
#fc modified
model.fc = torch.nn.Linear(in_features=2048,out_features=8,bias=True)
print(model)
5.同样在开头设置batchsize、epoch、指定gpu等就不赘述
好了,然后直接运行train.py就可以开始训练了。
本次训练的最优结果:
青古の每篇一歌
《外婆》
外婆她的无奈
无法变成期待
只有爱才能够明白