TensorFlow学习笔记--- 使用CPABD实现最简单的CNN模型

import os
from tensorflow.keras.datasets import mnist
import tensorflow as tf
from tensorflow.python.keras import Model
from tensorflow.python.keras.datasets import cifar10
from tensorflow.python.keras.layers import Flatten, Dense, Conv2D, BatchNormalization, AvgPool2D, Activation, MaxPool2D, \
    Dropout

(x_train, y_train), (x_test, y_test) = cifar10.load_data()

x_train, x_test = x_train/255.0, x_test/255.0

checkpoint_save_path = './checkpoint/model.ckpt'


# 搭建模型类, 口诀:CBAPD,C卷积B批标准化A激活P池化D全连接
class ConvModel(Model):
    def __init__(self):
        super(ConvModel, self).__init__()
        # filters: 卷积核个数  kernel_size:卷积核尺寸  strides:横纵向步长  padding:是否使用全零填充,same为是 activation:激活函数
        self.conv1 = Conv2D(filters=6, kernel_size=(5, 5), strides=(1, 1), padding='same', activation=None)
        # 在激活函数前,先进行一次批标准化,使得输入值更靠近0均值
        self.bn = BatchNormalization()
        # 激活函数
        self.activation = Activation('relu')
        # 池化,减少输入特征值
        self.pool = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')
        # Dropout防止过拟合
        self.dropout1 = Dropout(0.2)

        # 特征抽取完,拉直维度后通过全连接层输出
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.dropout2 = Dropout(0.2)
        self.d2 = Dense(10, activation='softmax')

    def call(self, x):
        x = self.conv1(x)
        x = self.bn(x)
        x = self.activation(x)
        x = self.pool(x)
        x = self.dropout1(x)
        x = self.flatten(x)
        x = self.d1(x)
        x = self.dropout2(x)
        y = self.d2(x)
        return y


model = ConvModel()

# 模型优化
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.sparse_categorical_crossentropy,
              metrics=['sparse_categorical_accuracy'])

# callback保存模型
model_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path, save_weights_only=True,
                                                    save_best_only=True)

# 曾经保存过,直接加载权重参数
if os.path.exists(checkpoint_save_path + '.index'):
    model.load_weights(checkpoint_save_path)

# 开始训练
model.fit(x=x_train, y=y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), callbacks=[model_callback])

# 结果总览
model.summary()

 

上一篇:将文本文件中的\n字符串变成换行符


下一篇:八、ResNet的网络结构及其代码实现(花的三分类)