TensorFlow 2.0 版本安装教程+CNN图像mnist手写数字识别实战代码-CNN图像识别

原文链接:https://geektutu.com/post/tensorflow2-mnist-cnn.html

转载自极客兔兔:
https://geektutu.com/post/tensorflow2-mnist-cnn.html
基于TensorFlow 2.0 版本训练CNN网络——实现mnist手写数字识别训练

TensorFlow 2.0 版本安装教程+CNN图像mnist手写数字识别实战代码-CNN图像识别
代码目录结构
data_set_tf2/ # TensorFlow 2.0的mnist数据集
|–mnist.npz
test_images/ # 预测所用的图片
|–0.png
|–1.png
|–4.png
v4_cnn/
|–ckpt/ # 模型保存的位置
|–checkpoint
|–cp-0005.ckpt.data-00000-of-00001
|–cp-0005.ckpt.index
|–predict.py # 预测代码
|–train.py # 训练代码
3. CNN模型代码(train.py)
模型定义的前半部分主要使用Keras.layers提供的Conv2D(卷积)与MaxPooling2D(池化)函数。

CNN的输入是维度为 (image_height, image_width, color_channels)的张量,mnist数据集是黑白的,因此只有一个color_channel(颜色通道),一般的彩色图片有3个(R,G,B),熟悉Web前端的同学可能知道,有些图片有4个通道(R,G,B,A),A代表透明度。对于mnist数据集,输入的张量维度就是(28,28,1),通过参数input_shape传给网络的第一层。
下面直接上代码:



import os
import tensorflow as tf
from tensorflow.keras import datasets, layers, models


class CNN(object):
    def __init__(self):
        model = models.Sequential()
        # 第1层卷积,卷积核大小为3*3,32个,28*28为待训练图片的大小
        model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
        model.add(layers.MaxPooling2D((2, 2)))
        # 第2层卷积,卷积核大小为3*3,64个
        model.add(layers.Conv2D(64, (3, 3), activation='relu'))
        model.add(layers.MaxPooling2D((2, 2)))
        # 第3层卷积,卷积核大小为3*3,64个
        model.add(layers.Conv2D(64, (3, 3), activation='relu'))

        model.add(layers.Flatten())
        model.add(layers.Dense(64, activation='relu'))
        model.add(layers.Dense(10, activation='softmax'))

        model.summary()

        self.model = model
model.summary()用来打印我们定义的模型的结构。

我们可以看到,每一个Conv2D和MaxPooling2D层的输出都是一个三维的张量(height, width, channels)。height和width会逐渐地变小。输出的channel的个数,是由第一个参数(例如,32或64)控制的,随着height和width的变小,channel可以变大(从算力的角度)。

模型的后半部分,是定义输出张量的。layers.Flatten会将三维的张量转为一维的向量。展开前张量的维度是(3, 3, 64) ,转为一维(576)的向量后,紧接着使用layers.Dense层,构造了2层全连接层,逐步地将一维向量的位数从576变为64,再变为10。
安装教程与2.0版本的CNN图像识别模型训练全文与代码链接地址:Tensorflow2.0版本 mnest手写数字识别


《AI工匠BOOK》持续更新AI算法与最新应用,如果您感兴趣,欢迎关注AI工匠(AI算法与最新应用前沿研究)。

TensorFlow 2.0 版本安装教程+CNN图像mnist手写数字识别实战代码-CNN图像识别

上一篇:阅读理解-bidaf模型


下一篇:callback