使用tensorflow进行mnist数字识别【模型训练+预测+模型保存+模型恢复】

 
import sys,os
sys.path.append(os.pardir)
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from PIL import Image
import tensorflow as tf def predict():
meta_path = 'ckpt/mnist.ckpt.meta'
model_path = 'ckpt/mnist.ckpt'
sess = tf.InteractiveSession ()
saver = tf.train.import_meta_graph (meta_path)
saver.restore (sess, model_path)
graph = tf.get_default_graph ()
W = graph.get_tensor_by_name ("w:0")
b = graph.get_tensor_by_name ("b:0")
x = tf.placeholder (tf.float32, [None, 784])
y = tf.nn.softmax (tf.matmul (x, W) + b)
keep_prob = tf.placeholder (tf.float32)
batch_xs, batch_ys=mnist.train.next_batch (100)
one_img = batch_xs[0].reshape ((1, 784))
one_num = batch_ys[0].reshape ((1, 10))
temp = sess.run (y, feed_dict={x: one_img, keep_prob: 1.0})
b = sess.run (tf.argmax (temp, 1))
a = sess.run (tf.arg_max (one_num, 1))
print(temp)
print(one_num)
if b == a:
print ("success! the num is :", (b[0]))
showImgTest(one_img)
else:
print ("mistakes predict.") def trainNet():
x = tf.placeholder (tf.float32, [None, 784])
W = tf.Variable (tf.zeros ([784, 10]),name="w")
b = tf.Variable (tf.zeros ([10]),name="b")
y = tf.nn.softmax (tf.matmul (x, W) + b)
y_ = tf.placeholder (tf.float32, [None, 10])
keep_prob = tf.placeholder (tf.float32)
# 定义测试的准确率
correct_prediction = tf.equal (tf.argmax (y, 1), tf.argmax (y_, 1))
accuracy = tf.reduce_mean (tf.cast (correct_prediction, tf.float32))
#
saver = tf.train.Saver (max_to_keep=1)
max_acc = 0
train_accuracy = 0
#交叉熵
cross_entropy = tf.reduce_mean (-tf.reduce_sum (y_ * tf.log (y)))
# cross_error=cross_entropy_error_batch(y,y_)
train_step = tf.train.GradientDescentOptimizer (0.01).minimize (cross_entropy)
sess = tf.InteractiveSession ()
tf.global_variables_initializer ().run ()
for i in range (1000):
batch_xs, batch_ys = mnist.train.next_batch (100)
sess.run (train_step, feed_dict={x: batch_xs, y_: batch_ys, keep_prob: 1.0})
if i % 100 == 0:
train_accuracy = accuracy.eval (feed_dict={x: batch_xs, y_: batch_ys, keep_prob: 1.0})
print ("step %d, training accuracy %g" % (i, train_accuracy))
if train_accuracy > max_acc:
max_acc = train_accuracy
saver.save (sess, 'ckpt/mnist.ckpt') if __name__ == '__main__':
mnist = input_data.read_data_sets ("MNIST_data/", one_hot=True)
choice=0
while choice == 0:
print ("------------------------tensorflow--------------------------")
print ("\t\t\t1\ttrain model..")
print("\t\t\t2\tpredict model")
print("\t\t\t3\tshow the first image")
print ("\t\t\t0\texit")
choice = input ("please input your choice!")
print(choice)
if choice == "1":
print("start train...")
trainNet()
if choice=="2":
predict()
if choice=="3":
showImg()

使用tensorflow进行mnist数字识别【模型训练+预测+模型保存+模型恢复】

使用tensorflow进行mnist数字识别【模型训练+预测+模型保存+模型恢复】使用tensorflow进行mnist数字识别【模型训练+预测+模型保存+模型恢复】

使用tensorflow进行mnist数字识别【模型训练+预测+模型保存+模型恢复】使用tensorflow进行mnist数字识别【模型训练+预测+模型保存+模型恢复】

使用tensorflow进行mnist数字识别【模型训练+预测+模型保存+模型恢复】使用tensorflow进行mnist数字识别【模型训练+预测+模型保存+模型恢复】

使用tensorflow进行mnist数字识别【模型训练+预测+模型保存+模型恢复】使用tensorflow进行mnist数字识别【模型训练+预测+模型保存+模型恢复】

注:正在学习CNN,选项4还没有来的及做。后面补上

上一篇:微信小程序——代码片段汇集


下一篇:BZOJ 1022: [SHOI2008]小约翰的游戏John (Anti-nim)