Contents
1. Preface
2. Fruit classification with VGG
2.1 Training
2.2 Training results
2.3 Inference
1. Preface
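VGG is a convolutional neural network (CNN) architecture proposed by researchers at the University of Oxford (the Visual Geometry Group, which gives the network its name).
It is widely used for image classification and feature extraction. A VGG network consists of a stack of convolutional layers followed by fully connected layers. Its defining trait is simplicity: the whole network uses only 3x3 convolution filters and max-pooling layers. VGG is known for its depth; the VGG16 and VGG19 variants have 16 and 19 weight layers (convolutional plus fully connected), respectively. VGG influenced the design of later CNN architectures and achieved excellent performance on image classification benchmarks such as ImageNet.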
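Highlights of the VGG model:
- Stacks of small 3x3 convolutions replace larger kernels such as 7x7. Two stacked 3x3 convolutions have the same receptive field as one 5x5 convolution, and three have the same receptive field as one 7x7 convolution, yet they need fewer parameters and less computation (with C input and output channels, 3 × 9C² = 27C² weights instead of 49C²).
- A simpler, uniform architecture: the convolutional part is built entirely from convolutions and downsampling (max-pooling), avoiding the irregular designs of earlier networks.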
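A small worked calculation of the first point. This is a sketch that assumes stride-1 convolutions with padding, the same channel count C in every layer, and ignores biases; the helper names are made up for this example:

```python
def stacked_receptive_field(kernel_size: int, num_layers: int) -> int:
    """Receptive field of num_layers stride-1 convolutions stacked on top of each other."""
    rf = 1
    for _ in range(num_layers):
        rf += kernel_size - 1  # each stride-1 layer widens the field by (k - 1)
    return rf


def conv_weights(kernel_size: int, channels: int, num_layers: int = 1) -> int:
    """Weight count (biases ignored) of num_layers k x k convs, each mapping C -> C channels."""
    return num_layers * kernel_size * kernel_size * channels * channels


C = 512  # channel count used in VGG's deepest stages
print("3 x (3x3):", stacked_receptive_field(3, 3), conv_weights(3, C, 3))  # -> 7 7077888
print("1 x (7x7):", stacked_receptive_field(7, 1), conv_weights(7, C, 1))  # -> 7 12845056
```

Since every layer keeps the spatial size at stride 1 with padding, the same 27:49 ratio also applies to the multiply-adds per output position.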
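It is worth noting that VGG performs very well, but the network is bloated (VGG16 has about 138 million parameters, most of them in the fully connected layers), which makes it very expensive to train. This is VGG's biggest weakness.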
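In my view, VGG was quite successful. First, it showed that stacked 3x3 convolutions can replace larger kernels, and since VGG, 3x3 has become the default kernel size in most CNNs. Second, its simplicity pointed the way for later improvements: simply stacking convolutions and downsampling layers can achieve better results. ResNet later introduced shortcut connections and the bottleneck block, but its overall layout follows the same rule: each time the spatial resolution is halved by downsampling, the number of channels is doubled. In that sense it is an evolution of the VGG architecture.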
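The "halve the resolution, double the channels" pattern can be traced in a few lines. This sketch assumes the standard VGG11 layout (configuration A in the VGG paper, which torchvision's vgg11 also uses), a 224x224 input, and padded 3x3 convolutions; `VGG11_CFG` and `trace_shapes` are illustrative names:

```python
# VGG11 (configuration A): numbers are 3x3 conv output channels, "M" is 2x2 max-pooling
VGG11_CFG = [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"]


def trace_shapes(cfg, size=224, in_channels=3):
    """Return (channels, height/width) after each pooling stage; 3x3 convs use padding 1."""
    channels, stages = in_channels, []
    for layer in cfg:
        if layer == "M":
            size //= 2  # 2x2 max-pooling with stride 2 halves the resolution
            stages.append((channels, size))
        else:
            channels = layer  # a padded 3x3 conv keeps the size and only changes the channels
    return stages


stages = trace_shapes(VGG11_CFG)
for c, s in stages:
    print(f"{c:4d} channels, {s:3d} x {s:<3d}")
c, s = stages[-1]
print("flattened feature size:", c * s * s)  # 512 * 7 * 7 = 25088
```

Channels double from stage to stage until they are capped at 512, and the final 512 × 7 × 7 map flattens to 25088 values, which is the input size of VGG's first fully connected layer.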
2. Fruit classification with VGG
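Project download: Image recognition project: transfer learning with VGG-series networks (vgg11, vgg13, vgg16, etc.), image recognition project: 33-class fruit image classification (resource, ****文库)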
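In the project, data holds the dataset, inference is used to run inference on images, runs stores the results generated during training, train is the training script, predict is the inference script, and utils contains the required helper functions.
There is no separate validation script: out of habit, I integrated validation into train, so the model is evaluated during training.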
2.1 Training
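After preprocessing, the fruit dataset is organized as follows:
This is a 33-class fruit classification task, with about 11k training images and about 5k validation images.
The labels are stored in a dictionary (generated automatically by the code):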
{
"0": "Apple Braeburn",
"1": "Apple Granny Smith",
"2": "Apricot",
"3": "Avocado",
"4": "Banana",
"5": "Blueberry",
"6": "Cactus fruit",
"7": "Cantaloupe",
"8": "Cherry",
"9": "Clementine",
"10": "Corn",
"11": "Cucumber Ripe",
"12": "Grape Blue",
"13": "Kiwi",
"14": "Lemon",
"15": "Limes",
"16": "Mango",
"17": "Onion White",
"18": "Orange",
"19": "Papaya",
"20": "Passion Fruit",
"21": "Peach",
"22": "Pear",
"23": "Pepper Green",
"24": "Pepper Red",
"25": "Pineapple",
"26": "Plum",
"27": "Pomegranate",
"28": "Potato Red",
"29": "Raspberry",
"30": "Strawberry",
"31": "Tomato",
"32": "Watermelon"
}
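The exact code that generates this dictionary is not shown in the post. A minimal sketch, assuming the data is laid out with one sub-folder per class (as torchvision's ImageFolder expects); `build_class_indices` and the file name `class_indices.json` are hypothetical:

```python
import json
import os
import tempfile


def build_class_indices(root: str, out_file: str = "class_indices.json") -> dict:
    """Map "index" -> class name, assuming one sub-folder per class under root.

    Folder names are sorted, which matches the alphabetical order of the dictionary above
    (torchvision's ImageFolder also assigns class indices in sorted order).
    """
    classes = sorted(d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d)))
    indices = {str(i): name for i, name in enumerate(classes)}
    with open(out_file, "w", encoding="utf-8") as f:
        json.dump(indices, f, indent=4)  # JSON object keys are strings, hence str(i)
    return indices


# Usage on a throwaway folder tree with three classes
with tempfile.TemporaryDirectory() as tmp:
    for name in ["Banana", "Apricot", "Apple Braeburn"]:
        os.makedirs(os.path.join(tmp, "train", name))
    result = build_class_indices(os.path.join(tmp, "train"), os.path.join(tmp, "class_indices.json"))
print(result)  # {'0': 'Apple Braeburn', '1': 'Apricot', '2': 'Banana'}
```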
2.2 Training results
I trained vgg11, and the results were excellent: it reached 100% accuracy.
The available training options are:
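import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model", default='vgg11', type=str, help='vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn')
# type=bool would turn any non-empty string (even "False") into True, so use flag pairs such as --pretrained / --no-pretrained (Python 3.9+)
parser.add_argument("--pretrained", default=True, action=argparse.BooleanOptionalAction)  # use the official pretrained weights
parser.add_argument("--freeze_layers", default=True, action=argparse.BooleanOptionalAction)  # freeze weights
parser.add_argument("--batch-size", default=4, type=int)
parser.add_argument("--epochs", default=10, type=int)
parser.add_argument("--optim", default='SGD', type=str, help='SGD, Adam')  # optimizer choice
parser.add_argument('--lr', default=0.001, type=float)
parser.add_argument('--lrf', default=0.0001, type=float)  # final learning rate = lr * lrf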
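The comment on --lrf only says that the final learning rate is lr * lrf; the exact schedule the project uses is not stated. One common way to get that behaviour is a cosine decay of a multiplier from 1 down to lrf, and the sketch below assumes that shape. In PyTorch, such a function would typically be passed as lr_lambda to torch.optim.lr_scheduler.LambdaLR, which multiplies the initial lr by it at each step:

```python
import math


def cosine_lr_factor(epoch: int, epochs: int, lrf: float) -> float:
    """Multiplier for the base lr: 1.0 at epoch 0, decaying along a cosine to lrf at epoch == epochs."""
    return ((1 + math.cos(epoch * math.pi / epochs)) / 2) * (1 - lrf) + lrf


lr, lrf, epochs = 0.001, 0.0001, 10  # defaults from the argparse options above
for epoch in range(epochs + 1):
    print(epoch, lr * cosine_lr_factor(epoch, epochs, lrf))
```

With the default values, the learning rate starts at 0.001 and ends at 0.001 × 0.0001 = 1e-7.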
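Note that the network's output layer has been modified: the code automatically derives num classes (the number of classes) from the dataset, so you do not need to set it yourself.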
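import torch
# Replace the last fully connected layer so the network outputs num classes
tmp = net.classifier[3].out_features  # 4096, which is also the input size of classifier[6]
net.classifier[6] = torch.nn.Linear(tmp, num, bias=True)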
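The parameter counts reported in the training log later in this section ("total paramerters" and "train paramerters") can be reproduced by hand. This sketch assumes the torchvision VGG11 layout (eight 3x3 conv layers in features, then Linear(25088, 4096), Linear(4096, 4096) and the replaced Linear(4096, num)), 33 classes, and that --freeze_layers freezes only the convolutional features while the whole classifier stays trainable; that last point is inferred from the numbers, not read from the project code:

```python
VGG11_CONV_CHANNELS = [64, 128, 256, 256, 512, 512, 512, 512]  # output channels of the 8 conv layers


def conv_params(in_c: int, out_c: int, k: int = 3) -> int:
    return in_c * out_c * k * k + out_c  # weights + biases


def linear_params(in_f: int, out_f: int) -> int:
    return in_f * out_f + out_f  # weights + biases


def vgg11_param_counts(num_classes: int) -> tuple:
    in_c, features = 3, 0
    for out_c in VGG11_CONV_CHANNELS:
        features += conv_params(in_c, out_c)
        in_c = out_c
    classifier = (linear_params(512 * 7 * 7, 4096)
                  + linear_params(4096, 4096)
                  + linear_params(4096, num_classes))  # classifier[6], replaced for our classes
    return features + classifier, classifier  # (total, trainable when features are frozen)


total, trainable = vgg11_param_counts(33)
print(total, trainable)  # 128901537 119681057
```

The difference between the two numbers (about 9.2M) is the entire convolutional part, which shows how much of VGG's size sits in the fully connected layers.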
The results are saved under runs:
All training logs are stored in a JSON log file; an excerpt is shown below:
{
"train parameters": {
"model": "vgg11",
"pretrained": true,
"freeze_layers": true,
"batch_size": 4,
"epochs": 10,
"optim": "SGD",
"lr": 0.001,
"lrf": 0.0001
},
"total paramerters": 128901537,
"train paramerters": 119681057,
"epoch:0": {
"train info": {
"accuracy": 0.9742591024547211,
"Apple Braeburn": {
"Precision": 0.948,
"Recall": 0.9507,
"Specificity": 0.9984,
"F1 score": 0.9493
},
"Apple Granny Smith": {
"Precision": 0.9883,
"Recall": 0.9768,
"Specificity": 0.9997,
"F1 score": 0.9825
},
"Apricot": {
"Precision": 0.9791,
"Recall": 0.9507,
"Specificity": 0.9994,
"F1 score": 0.9647
},
"Avocado": {
"Precision": 0.9691,
"Recall": 0.9431,
"Specificity": 0.9992,
"F1 score": 0.9559
},
"Banana": {
"Precision": 0.9912,
"Recall": 0.9883,
"Specificity": 0.9997,
"F1 score": 0.9897
},
"Blueberry": {
"Precision": 0.9753,
"Recall": 0.9753,
"Specificity": 0.9993,
"F1 score": 0.9753
},
"Cactus fruit": {
"Precision": 0.9797,
"Recall": 0.9825,
"Specificity": 0.9994,
"F1 score": 0.9811
},
"Cantaloupe": {
"Precision": 0.9885,
"Recall": 0.9942,
"Specificity": 0.9997,
"F1 score": 0.9913
},
"Cherry": {
"Precision": 0.9883,
"Recall": 0.9768,
"Specificity": 0.9997,
"F1 score": 0.9825
},
"Clementine": {
"Precision": 0.9218,
"Recall": 0.9621,
"Specificity": 0.9976,
"F1 score": 0.9415
},
"Corn": {
"Precision": 1.0,
"Recall": 0.9905,
"Specificity": 1.0,
"F1 score": 0.9952
},
"Cucumber Ripe": {
"Precision": 0.9926,
"Recall": 0.9782,
"Specificity": 0.9998,
"F1 score": 0.9853
},
"Grape Blue": {
"Precision": 0.9717,
"Recall": 0.9956,
"Specificity": 0.9982,
"F1 score": 0.9835
},
"Kiwi": {
"Precision": 0.9698,
"Recall": 0.9817,
"Specificity": 0.9991,
"F1 score": 0.9757
},
"Lemon": {
"Precision": 0.971,
"Recall": 0.971,
"Specificity": 0.9991,
"F1 score": 0.971
},
"Limes": {
"Precision": 0.9443,
"Recall": 0.9883,
"Specificity": 0.9983,
"F1 score": 0.9658
},
"Mango": {
"Precision": 0.9561,
"Recall": 0.9534,
"Specificity": 0.9987,
"F1 score": 0.9547
},
"Onion White": {
"Precision": 0.9741,
"Recall": 0.9805,
"Specificity": 0.9993,
"F1 score": 0.9773
},
"Orange": {
"Precision": 0.9852,
"Recall": 0.9911,
"Specificity": 0.9996,
"F1 score": 0.9881
},
"Papaya": {
"Precision": 0.9431,
"Recall": 0.913,
"Specificity": 0.9983,
"F1 score": 0.9278
},
"Passion Fruit": {
"Precision": 0.9941,
"Recall": 0.9767,
"Specificity": 0.9998,
"F1 score": 0.9853
},
"Peach": {
"Precision": 0.9263,
"Recall": 0.9478,
"Specificity": 0.9977,
"F1 score": 0.9369
},
"Pear": {
"Precision": 0.9754,
"Recall": 0.9754,
"Specificity": 0.9989,
"F1 score": 0.9754
},
"Pepper Green": {
"Precision": 0.9936,
"Recall": 0.9936,
"Specificity": 0.9998,
"F1 score": 0.9936
},
"Pepper Red": {
"Precision": 0.9789,
"Recall": 0.9936,
"Specificity": 0.9991,
"F1 score": 0.9862
},
"Pineapple": {
"Precision": 0.9971,
"Recall": 0.9913,
"Specificity": 0.9999,
"F1 score": 0.9942
},
"Plum": {
"Precision": 0.9904,
"Recall": 0.984,
"Specificity": 0.9997,
"F1 score": 0.9872
},
"Pomegranate": {
"Precision": 0.9531,
"Recall": 0.942,
"Specificity": 0.9986,
"F1 score": 0.9475
},
"Potato Red": {
"Precision": 0.9773,
"Recall": 0.9556,
"Specificity": 0.9994,
"F1 score": 0.9663
},
"Raspberry": {
"Precision": 1.0,
"Recall": 0.9883,
"Specificity": 1.0,
"F1 score": 0.9941
},
"Strawberry": {
"Precision": 0.9741,
"Recall": 0.9826,
"Specificity": 0.9992,
"F1 score": 0.9783
},
"Tomato": {
"Precision": 0.9728,
"Recall": 0.9691,
"Specificity": 0.9988,
"F1 score": 0.9709
},
"Watermelon": {
"Precision": 0.997,
"Recall": 0.982,
"Specificity": 0.9999,
"F1 score": 0.9894
},
"mean precision": 0.9747666666666668,
"mean recall": 0.9735090909090911,
"mean specificity": 0.9991909090909091,
"mean f1 score": 0.9740454545454547
},
"valid info": {
"accuracy": 0.9924662965880403,
"Apple Braeburn": {
"Precision": 0.9671,
"Recall": 1.0,
"Specificity": 0.999,
"F1 score": 0.9833
},
"Apple Granny Smith": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"Apricot": {
"Precision": 0.9545,
"Recall": 1.0,
"Specificity": 0.9986,
"F1 score": 0.9767
},
"Avocado": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"Banana": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"Blueberry": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"Cactus fruit": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"Cantaloupe": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"Cherry": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"Clementine": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"Corn": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"Cucumber Ripe": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"Grape Blue": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"Kiwi": {
"Precision": 0.9789,
"Recall": 1.0,
"Specificity": 0.9994,
"F1 score": 0.9893
},
"Lemon": {
"Precision": 0.9735,
"Recall": 1.0,
"Specificity": 0.9992,
"F1 score": 0.9866
},
"Limes": {
"Precision": 0.9932,
"Recall": 1.0,
"Specificity": 0.9998,
"F1 score": 0.9966
},
"Mango": {
"Precision": 1.0,
"Recall": 0.9932,
"Specificity": 1.0,
"F1 score": 0.9966
},
"Onion White": {
"Precision": 1.0,
"Recall": 0.9924,
"Specificity": 1.0,
"F1 score": 0.9962
},
"Orange": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"Papaya": {
"Precision": 0.9859,
"Recall": 0.9524,
"Specificity": 0.9996,
"F1 score": 0.9689
},
"Passion Fruit": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"Peach": {
"Precision": 0.9074,
"Recall": 1.0,
"Specificity": 0.9969,
"F1 score": 0.9515
},
"Pear": {
"Precision": 0.9951,
"Recall": 0.9808,
"Specificity": 0.9998,
"F1 score": 0.9879
},
"Pepper Green": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"Pepper Red": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"Pineapple": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"Plum": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"Pomegranate": {
"Precision": 1.0,
"Recall": 0.9252,
"Specificity": 1.0,
"F1 score": 0.9611
},
"Potato Red": {
"Precision": 1.0,
"Recall": 0.8963,
"Specificity": 1.0,
"F1 score": 0.9453
},
"Raspberry": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"Strawberry": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"Tomato": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"Watermelon": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"mean precision": 0.9925939393939395,
"mean recall": 0.9921303030303031,
"mean specificity": 0.9997666666666667,
"mean f1 score": 0.992121212121212
}
},
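After only a few epochs, accuracy had already reached 100%. Even after the first epoch (epoch:0), the log above shows 97.4% training accuracy and 99.2% validation accuracy.
Loss and accuracy curves:
Curves of the other evaluation metrics:
Confusion matrices for the training and validation sets: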
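The per-class Precision, Recall, Specificity and F1 score in the log are standard one-vs-rest quantities that can be read off a confusion matrix. A minimal pure-Python sketch, assuming rows are true labels and columns are predictions; the project's own implementation and rounding are not shown in the post, and the 3-class matrix below is a toy example:

```python
def per_class_metrics(cm):
    """cm[i][j] = number of samples with true class i predicted as class j."""
    n = len(cm)
    total = sum(sum(row) for row in cm)
    results = []
    for k in range(n):
        tp = cm[k][k]
        fn = sum(cm[k]) - tp                       # true class k, predicted as something else
        fp = sum(cm[i][k] for i in range(n)) - tp  # predicted as k, true class is something else
        tn = total - tp - fn - fp
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        specificity = tn / (tn + fp) if tn + fp else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        results.append((precision, recall, specificity, f1))
    accuracy = sum(cm[k][k] for k in range(n)) / total
    return accuracy, results


cm = [[8, 2, 0],
      [1, 9, 0],
      [0, 0, 10]]
accuracy, metrics = per_class_metrics(cm)
print(round(accuracy, 4))                   # 0.9
print([round(x, 4) for x in metrics[0]])    # [0.8889, 0.8, 0.95, 0.8421]
```

Averaging a metric over the classes gives the "mean ..." entries: for example, the plain average of the 33 per-class training precisions in the log is 0.97477, matching its "mean precision".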
2.3 Inference
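Inference uses the predict script. Its parameters are shown below; --model must be the same version that was used for training.
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--model", default='vgg11', type=str, help='vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn')
parser.add_argument("--weights", default='runs/weights/best.pth', type=str, help='best or last')
Just put the images you want to classify under infer_img.
Then run the script to perform inference: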
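The step that turns the network output into a fruit name is not listed in this post. As a hedged sketch: a classifier like this returns one logit per class, softmax turns the logits into probabilities, and the index of the largest one is looked up in the label dictionary from section 2.1. The function name `top1` and the toy logits are hypothetical; with PyTorch tensors the same steps would be torch.softmax followed by argmax:

```python
import math


def top1(logits, class_indices):
    """Return (class name, probability) for the highest-scoring class."""
    m = max(logits)                                # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)
    return class_indices[str(best)], probs[best]   # dictionary keys are strings, e.g. "4"


labels = {"0": "Apple Braeburn", "1": "Apple Granny Smith", "2": "Apricot"}  # first entries of the dictionary
name, prob = top1([0.5, 3.0, -1.0], labels)
print(name, round(prob, 3))
```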