基于TensorFlow的MNIST手写数字识别-初级

一:MNIST数据集   

下载地址

MNIST是一个包含很多手写数字图片的数据集,一共4个二进制压缩文件

基于TensorFlow的MNIST手写数字识别-初级

分别是test set images,test set labels,training set images,training set labels

training set包括60000个样本,test set包括10000个样本。

test set中前5000个样本来自原始的NISTtraining set,后5000个样本来自原始的NIST test set,因此,前5000个样本比后5000个样本更简单和干净。

每个样本是28*28像素的图片

基于TensorFlow的MNIST手写数字识别-初级

基于TensorFlow的MNIST手写数字识别-初级

基于TensorFlow的MNIST手写数字识别-初级

基于TensorFlow的MNIST手写数字识别-初级

二:tensorflow构建模型识别MNIST

导入数据:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
import tensorflow as tf
x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10]) #真实值
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x, w) + b) #预测值

softmax的目的:将输出转化为是每个数字的概率

#计算交叉熵
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_label *tf.log(y), reduction_indices=[1]))
train = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

交叉熵:衡量预测值与真实值之间的差别,当然是越小越好

公式为:

基于TensorFlow的MNIST手写数字识别-初级

其中y'是真实值,y为预测值

最后用梯度下降法优化参数即可

在Session中运行graph:

total_steps = 5000
batch_size = 100
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for step in range(total_steps+1):
batch_x, batch_y = mnist.train.next_batch(batch_size)
sess.run(train,feed_dict={x: batch_x, y_label: batch_y})

 预测正确率:

correct_prediction = tf.equal(tf.argmax(y, axis=1), tf.argmax(y_label, axis=1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

tf.argmax()函数返回axis轴上最大值的index

tf.equal()函数返回的是布尔值,需要用tf.cast()方法转为tf.float32类型

最后在test set上进行预测:

step_per_test = 100
if step % step_per_test == 0:
print(step, sess.run(accuracy, feed_dict={x: mnist.test.images, y_label: mnist.test.labels}))

完整代码如下:

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)
x = tf.placeholder(tf.float32, [None, 784])
y_label = tf.placeholder(tf.float32, [None, 10])
w = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x, w) + b) #计算交叉熵
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_label *tf.log(y), reduction_indices=[1]))
train = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
#eval
correct_prediction = tf.equal(tf.argmax(y, axis=1), tf.argmax(y_label, axis=1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) total_steps = 5000
batch_size = 100
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for step in range(total_steps+1):
batch_x, batch_y = mnist.train.next_batch(batch_size)
sess.run(train,feed_dict={x: batch_x, y_label: batch_y}) step_per_test = 100
if step % step_per_test == 0:
print(step, sess.run(accuracy, feed_dict={x: mnist.test.images, y_label: mnist.test.labels}))

运行结果:

基于TensorFlow的MNIST手写数字识别-初级

准确率为0.92左右

后面我们会构建更好的模型达到更高的正确率。

相关链接:

详解 MNIST 数据集

基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型

基于tensorflow的MNIST手写数字识别(二)--入门篇

基于tensorflow的MNIST手写数字识别(三)--神经网络篇

上一篇:<转>DNS服务系列之二:DNS区域传送漏洞的安全案例


下一篇:【CAS单点登录视频教程】 第04集 -- tomcat下配置https环境