环境:TensorFlow==1.14.0 python3.7
Inception 模型优势
(1)分解成小卷积很有效,可以降低参数量,减轻过拟合,增加网络非线性的表达能力。
(2) 卷积网络从输入到输出,应该让图片尺寸逐渐减小,输出通道数逐渐增加,即让空间结构化,将空间信息转化为高阶抽象的特征信息。
(3) Inception Module用多个分支提取不同抽象程度的高阶特征的思路很有效,可以丰富网络的表达能力
详细可看inceptionV3论文《Rethinking the Inception Architecture for Computer Vision》
迁移学习
迁移学习在实际应用中的意义非常大,它可以将之前已学过的知识(模型参数)迁移到一项新的任务上,使学习效率大大的提高。我们知道,要训练一个复杂的深度学习模型,成本是十分巨大的。而迁移学习可以大大的降低我们的训练成本,在短时间内就能达到很好的效果。
这次展示的是基于垃圾分类训练,数据集来自Garbage Classification (12 classes) | Kaggle
采用对ImageNet训练过的权重做预训练模型,只对softmax训练
一、准备阶段
下载retrain.py
之前位置在
如果之后位置变化可以到README.md中查看
数据集准备:
自己拍照想要识别的物体,建议是纯色背景或者单一背景
也可以到Kaggle: Your Machine Learning and Data Science Community
然后建立以下目录
data放置数据集,如
tmp放置训练过程中产生的文件
二、训练模型
在retrain.py同级目录下打开cmd
输入python retrain.py --image_dir data --how_many_training_steps 1000 --model_dir inception_dec_2015 --output_graph output_graph.pd --output_labels output_labels.txt --bottleneck_dir tmp\bottleneck --summaries_dir tmp\retrain_logs
这里data是上述的data文件夹位置,建议都改为绝对位置,防止意外报错
具体参数可以查看retrain.py
训练结果
结束后可以在retrain.py同级目录下
inception_dec_2015下是下载的ImageNet训练模型,bottleneck是放置样本描述文件,retrain_logs放置训练日志,可以通过TensorBoard查看,output_graph.pd就是我们训练好的模型,output_labels.txt是标签
三、模型测试
# coding: UTF-8
import tensorflow as tf
import os
import numpy as np
import matplotlib.pyplot as plt
import random
# 结果数组与output_labels.txt文件中的顺序要一致
res = [
'battery', 'biological', 'brown glass', 'cardboard', 'clothes',
'green glass', 'metal', 'paper', 'plastic', 'shoes', 'trash', 'white glass'
]
path = r'test.jpg'
with tf.gfile.FastGFile('output_graph.pd', 'rb') as f:
graph_def = tf.compat.v1.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
with tf.Session() as sess:
softmax_tensor = sess.graph.get_tensor_by_name(
'final_result:0'
) # 获取新模型最后的输出节点叫做final_result,可以从tensorboard中的graph中看到,其中名字后面的’:’之后接数字为EndPoints索引值(An operation allocates memory for its outputs, which are available on endpoints :0, :1, etc, and you can think of each of these endpoints as a Tensor.),通常情况下为0,因为大部分operation都只有一个输出。
image_data = tf.gfile.FastGFile(
path, 'rb').read() # Returns the contents of a file as a string.
predictions = sess.run(
softmax_tensor,
{'DecodeJpeg/contents:0': image_data
}) # tensorboard中的graph中可以看到DecodeJpeg/contents是模型的输入变量名字
predictions = np.squeeze(predictions)
top_k = predictions.argsort()[-2:][::-1][0]
print(res[top_k])
运行后会打印置信度最高的物体标签
旧版本retrain.py的只针对对inceptionV3进行迁移学习训练,如果需要训练其他网络结构,得更新retrain.py一些参数,新版本就可以对其他网络结构训练。