day13-TensorFlow简单神经网络实现手写数字识别


# coding=utf-8
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

def numberRead():
    # 获取数据
    mnist = input_data.read_data_sets("../data/day06/", one_hot=True)

    # 1、准备数据集
    with tf.variable_scope("data"):
        # 准备占位符
        x = tf.placeholder(tf.float32,shape=[None,784])
        y_true = tf.placeholder(tf.int64,shape=[None,10])

        # 构建一个全连接层的网络,即权重和偏置
        weight = tf.Variable(tf.random_normal([784,10],mean=0.0,stddev=1.0))
        bias = tf.Variable(tf.random_normal([10],mean=0.0,stddev=1.0))

    # 2、构建模型
    with tf.variable_scope("model"):
        # None*784 乘 784*10 得到的结果为 None*10 即对应十个目标值
        y_predict = tf.matmul(x,weight) + bias

    # 3、模型参数计算
    with tf.variable_scope("model_soft_corss"):
        # 计算交叉熵损失
        softmax = tf.nn.softmax_cross_entropy_with_logits(labels=y_true,logits=y_predict)
        # 计算损失平均值
        loss = tf.reduce_mean(softmax)

    # 4、梯度下降(反向传播算法)优化模型
    with tf.variable_scope("model_better"):
        tarin_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

    # 5、计算准确率
    with tf.variable_scope("model_acc"):
        # 计算出每个样本是否预测成功,结果为:[1,0,1,0,0,0,....,1]
        equal_list = tf.equal(tf.argmax(y_true,1),tf.argmax(y_predict,1))

        # 计算出准确率,先将预测是否成功换为float可以得到详细的准确率
        acc = tf.reduce_mean(tf.cast(equal_list,tf.float32))


    # 6、准备工作
    # 定义变量初始化op
    init_op = tf.global_variables_initializer()
    # 定义哪些变量记录
    tf.summary.scalar("losses",loss)
    tf.summary.scalar("acces",acc)
    tf.summary.histogram("weightes",weight)
    tf.summary.histogram("biases",bias)
    merge = tf.summary.merge_all()

    # 开启会话运行
    with tf.Session() as sess:
        # 变量初始化
        sess.run(init_op)

        # 开启记录
        filewriter = tf.summary.FileWriter("../summary/day06/",graph=sess.graph)

        for i in range(2500):
            # 准备数据
            mnist_x, mnist_y = mnist.train.next_batch(50)

            # 开始训练
            sess.run([tarin_op],feed_dict={x:mnist_x,y_true:mnist_y})

            # 得出训练的准确率,注意还需要将数据填入
            print("第%d次训练,准确率为:%f" % ((i+1),sess.run(acc, feed_dict={x: mnist_x, y_true: mnist_y})))

            # 写入每步训练的值
            summary = sess.run(merge,feed_dict={x:mnist_x,y_true:mnist_y})
            filewriter.add_summary(summary,i)




    return None


if __name__ == '__main__':
    numberRead()





mnist数据集获取地址:http://yann.lecun.com/exdb/mnist/

训练效果:
{{uploading-image-661248.png(uploading...)}}

上一篇:Day13 PythonWeb全栈课程课堂内容


下一篇:linux day13 用户基本概述