使用 TensorFlow 实现 CNN 进行 MNIST 手写数字识别

这是我的第一份 CNN 代码，目前对 CNN 的反向传播（BP）还不熟悉，但通过这份代码，我对 TensorFlow 的运行机制有了初步的理解。

 

 1 '''
 2 softmax classifier for mnist  
 3 
 4 created on 2019.9.28
 5 author: vince
 6 '''
 7 import math
 8 import logging
 9 import numpy  
10 import random
11 import matplotlib.pyplot as plt
12 import tensorflow as tf
13 from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
14 from sklearn.metrics import accuracy_score
15 
def weight_bais_variable(shape):
    """Create a trainable weight variable drawn from a truncated normal.

    Args:
        shape: shape of the variable, e.g. ``[5, 5, 1, 32]`` for a conv kernel
            or ``[1024, 10]`` for a fully connected layer.

    Returns:
        A ``tf.Variable`` initialised from a truncated normal distribution
        with standard deviation 0.01.
    """
    initial_value = tf.random.truncated_normal(shape=shape, stddev=0.01)
    weights = tf.Variable(initial_value)
    return weights
19 
def bais_variable(shape):
    """Create a trainable bias variable with every entry set to 0.1.

    The small positive start value is presumably chosen so that the ReLU
    units these biases feed into begin in their active region.

    Args:
        shape: shape of the bias, e.g. ``[32]``.

    Returns:
        A ``tf.Variable`` filled with the constant 0.1.
    """
    return tf.Variable(tf.constant(0.1, shape=shape))
23 
def conv2d(x, w):
    """Apply a 2-D convolution of ``x`` with kernel ``w``.

    Uses stride 1 in every dimension and "SAME" padding, so the spatial size
    of ``x`` is preserved (the callers' shape comments show 28x28 -> 28x28
    and 14x14 -> 14x14).
    """
    unit_strides = [1, 1, 1, 1]
    return tf.nn.conv2d(x, w, strides=unit_strides, padding="SAME")
26 
def max_pool_2x2(x):
    """Apply 2x2 max pooling with stride 2 and "SAME" padding.

    Height and width are halved (28x28 -> 14x14 -> 7x7 in ``cnn``).
    """
    window = [1, 2, 2, 1]
    return tf.nn.max_pool2d(x, ksize=window, strides=window, padding="SAME")
29 
def cnn(x, rate):
    """Build the two-convolution-layer CNN graph and return class logits.

    Layer shapes (per example):
        reshape  784 -> 28x28x1
        conv1    5x5 kernel, 32 filters, ReLU  -> 28x28x32
        pool1    2x2 max pool                  -> 14x14x32
        conv2    5x5 kernel, 64 filters, ReLU  -> 14x14x64
        pool2    2x2 max pool                  -> 7x7x64
        fc1      3136 -> 1024, ReLU, dropout
        fc2      1024 -> 10, no activation (logits)

    Args:
        x: float tensor of flattened images, reshaped here to
            ``[-1, 28, 28, 1]`` (the caller feeds ``[None, 784]``).
        rate: scalar dropout *drop* probability applied to the fc1
            activations; feed e.g. 0.5 while training and 0.0 when evaluating.

    Returns:
        Unscaled logits tensor of shape ``[batch, 10]``.
    """
    with tf.name_scope('reshape'):
        x_image = tf.reshape(x, [-1, 28, 28, 1])

    # First layer: convolution + pooling.
    with tf.name_scope('conv1'):
        w_conv1 = weight_bais_variable([5, 5, 1, 32])
        b_conv1 = bais_variable([32])
        h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1)  # 28 * 28 * 32
    with tf.name_scope('pool1'):
        h_pool1 = max_pool_2x2(h_conv1)  # 14 * 14 * 32

    # Second layer: convolution + pooling.
    with tf.name_scope('conv2'):
        w_conv2 = weight_bais_variable([5, 5, 32, 64])
        b_conv2 = bais_variable([64])
        h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2)  # 14 * 14 * 64
    with tf.name_scope('pool2'):
        h_pool2 = max_pool_2x2(h_conv2)  # 7 * 7 * 64

    # First fully connected layer: flatten the feature maps into a vector.
    with tf.name_scope('fc1'):
        w_fc1 = weight_bais_variable([7 * 7 * 64, 1024])
        b_fc1 = bais_variable([1024])
        h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
        h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)
    with tf.name_scope("dropout1"):
        # `rate` must be passed by keyword. In TF 1.x the second positional
        # parameter of tf.nn.dropout is the deprecated `keep_prob`, so a
        # positional 0.0 at evaluation time meant "keep nothing" and a 1/0
        # rescale factor.
        h_fc1_drop = tf.nn.dropout(h_fc1, rate=rate)

    # Second fully connected layer: produces the 10 class logits.
    with tf.name_scope('fc2'):
        w_fc2 = weight_bais_variable([1024, 10])
        b_fc2 = bais_variable([10])
        # Consume the dropped-out activations. Previously h_fc1 was used here,
        # so the dropout op was built but never affected the output.
        h_fc2 = tf.matmul(h_fc1_drop, w_fc2) + b_fc2
    return h_fc2
66 
67 
def main(data_dir='../data/MNIST', steps=300, batch_size=128,
         learning_rate=0.5, eval_every=10, dropout_rate=0.5):
    """Train the CNN on MNIST and periodically log test-set accuracy.

    Every argument defaults to the value that used to be hard-coded here,
    so ``main()`` keeps its previous behavior.

    Args:
        data_dir: directory passed to ``read_data_sets`` for the MNIST files.
        steps: number of mini-batch gradient-descent steps.
        batch_size: training examples drawn per step.
        learning_rate: step size for plain gradient descent.
        eval_every: log test accuracy every this many steps (must be > 0).
        dropout_rate: drop probability fed to the fc1 dropout while training
            (evaluation always feeds 0.0).
    """
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
                        datefmt='%a, %d %b %Y %H:%M:%S')

    # one_hot=True encodes each label as a length-10 one-hot vector, matching
    # the [None, 10] shape of y_real below.
    mnist = read_data_sets(data_dir, one_hot=True)

    x = tf.placeholder(tf.float32, [None, 784])
    y_real = tf.placeholder(tf.float32, [None, 10])
    rate = tf.placeholder(tf.float32)  # dropout drop probability

    y_pre = cnn(x, rate)

    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=y_pre, labels=y_real))
    train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

    correct_prediction = tf.equal(tf.argmax(y_pre, 1), tf.argmax(y_real, 1))
    prediction_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # Build the initializer only after the whole graph exists, so any
    # variables created by the optimizer (e.g. slots for momentum or Adam)
    # are covered. The `with` block closes the session on exit.
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_op)
        for step in range(steps):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_op, feed_dict={
                x: batch_xs, y_real: batch_ys, rate: dropout_rate})
            if step % eval_every == 0:
                accuracy = sess.run(prediction_op, feed_dict={
                    x: mnist.test.images, y_real: mnist.test.labels, rate: 0.0})
                logging.info("%s : %s", step, accuracy)


if __name__ == "__main__":
    main()

 

上一篇:剑指offer第二版面试题5:从尾到头打印链表(JAVA版)


下一篇:手写数字识别