Offline download, data loading, and CNN training for cifar10 in an Anaconda notebook on Windows 10

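1. Downloading the data from the official site

Network restrictions sometimes prevent the cifar10 data from being loaded directly, in which case the offline data package has to be downloaded. The official URL is: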

https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
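If a machine with unrestricted access is available, the download can also be scripted instead of done in a browser. The following is a minimal sketch using only the Python standard library; the helper name `download_file` and the constant `CIFAR10_URL` are hypothetical names chosen for this illustration.

```python
import shutil
import urllib.request
from pathlib import Path

# Official archive URL quoted above.
CIFAR10_URL = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"


def download_file(url, dest):
    """Download url to dest and return dest as a Path (hypothetical helper)."""
    dest = Path(dest)
    dest.parent.mkdir(parents=True, exist_ok=True)
    # Stream the response to disk instead of holding the archive in memory.
    with urllib.request.urlopen(url) as response, open(dest, "wb") as out:
        shutil.copyfileobj(response, out)
    return dest
```

Calling `download_file(CIFAR10_URL, "cifar-10-python.tar.gz")` would save the archive in the current directory; the file can then be copied to the offline machine.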

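2. Renaming and extracting the archive

Rename the downloaded archive to cifar-10-batches-py.tar.gz, place it under user/xxx/.keras/datasets, and extract it there. Then create a folder named cifar-10-batches-py in the datasets directory and copy all of the extracted files (not the folder that contains them) into it.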
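The steps above can also be scripted. This is a minimal sketch, assuming the archive has already been downloaded and that it unpacks to a single top-level folder named cifar-10-batches-py; the helper `install_cifar10_archive` and its parameters are hypothetical and not part of Keras.

```python
import shutil
import tarfile
from pathlib import Path


def install_cifar10_archive(archive_path, datasets_dir=None):
    """Copy the archive into the Keras cache folder and extract it there.

    Hypothetical helper. Returns the path of the extracted folder.
    """
    archive_path = Path(archive_path)
    if datasets_dir is None:
        # Default cache location, i.e. user/xxx/.keras/datasets.
        datasets_dir = Path.home() / ".keras" / "datasets"
    datasets_dir = Path(datasets_dir)
    datasets_dir.mkdir(parents=True, exist_ok=True)

    # Store the archive under the name used in this section.
    target = datasets_dir / "cifar-10-batches-py.tar.gz"
    if archive_path.resolve() != target.resolve():
        shutil.copyfile(archive_path, target)

    # Assumption: the archive contains a top-level cifar-10-batches-py
    # folder, so extracting here yields datasets_dir/cifar-10-batches-py.
    with tarfile.open(target, "r:gz") as tar:
        tar.extractall(datasets_dir)

    return datasets_dir / "cifar-10-batches-py"
```

Passing the path of the downloaded cifar-10-python.tar.gz as `archive_path` should leave the batch files directly inside the cifar-10-batches-py folder, which is the layout described above.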

3. Loading the data

Import the module: from tensorflow.keras import datasets

Read the data:

(x_train,y_train), (x_test,y_test) = datasets.cifar10.load_data()
x_train,x_test = x_train/255.0, x_test/255.0
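If load_data() still tries to reach the network, the extracted batch files can be read directly. The sketch below assumes the standard layout of the Python version of the dataset (pickled dictionaries with `data` and `labels` entries, 3072 values per image in red, green, blue plane order); the function names `load_batch` and `load_cifar10_from_dir` are hypothetical.

```python
import pickle
from pathlib import Path

import numpy as np


def load_batch(batch_path):
    """Read one batch file and return (images, labels)."""
    with open(batch_path, "rb") as f:
        # The files were pickled under Python 2, so keys are read as bytes.
        batch = pickle.load(f, encoding="bytes")
    # Each row holds 1024 red, then 1024 green, then 1024 blue values.
    data = np.asarray(batch[b"data"], dtype=np.uint8)
    # Reshape to (N, 3, 32, 32), then move channels last: (N, 32, 32, 3).
    images = data.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    labels = np.asarray(batch[b"labels"], dtype=np.uint8).reshape(-1, 1)
    return images, labels


def load_cifar10_from_dir(folder):
    """Load the five training batches and the test batch from folder."""
    folder = Path(folder)
    train = [load_batch(folder / f"data_batch_{i}") for i in range(1, 6)]
    x_train = np.concatenate([x for x, _ in train])
    y_train = np.concatenate([y for _, y in train])
    x_test, y_test = load_batch(folder / "test_batch")
    return (x_train, y_train), (x_test, y_test)
```

The returned arrays are still uint8, so the division by 255.0 shown above applies to them in the same way.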

4. CNN training on the cifar10 data (code mainly from https://blog.csdn.net/yanghe4405/article/details/107521797)

import tensorflow as tf
import os
import numpy as np
from matplotlib import pyplot as plt
from tensorflow.keras import datasets
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense
from tensorflow.keras import Model

np.set_printoptions(threshold=np.inf)

#cifar10 = tf.keras.datasets.cifar10
#(x_train,y_train), (x_test,y_test) = cifar10.load_data()
(x_train,y_train), (x_test,y_test) = datasets.cifar10.load_data()
x_train,x_test = x_train/255.0, x_test/255.0

class Baseline(Model):
    def __init__(self):
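        # Prepare every layer needed to build the network here, i.e. CBAPD
        # (Convolution, BatchNormalization, Activation, Pooling, Dropout)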
        super(Baseline, self).__init__()
        self.c1 = Conv2D(filters=6, kernel_size=(5, 5), padding='same')  # convolution layer
        self.b1 = BatchNormalization()  # BN layer
        self.a1 = Activation('relu')  # activation layer
        self.p1 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')  # pooling layer
        self.d1 = Dropout(0.2)  # dropout layer

        self.flatten = Flatten()
        self.f1 = Dense(128, activation='relu')
        self.d2 = Dropout(0.2)
        self.f2 = Dense(10, activation='softmax')
    def call(self, x):
        x = self.c1(x)
        x = self.b1(x)
        x = self.a1(x)
        x = self.p1(x)
        x = self.d1(x)

        x = self.flatten(x)
        x = self.f1(x)
        x = self.d2(x)
        y = self.f2(x)
        return y


model = Baseline()

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

checkpoint_save_path = "./checkpoint/Baseline.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)

history = model.fit(x_train, y_train, batch_size=32, epochs=20, validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])
model.summary()

# print(model.trainable_variables)
# Write each trainable variable's name, shape, and values to a text file.
with open('./weights.txt', 'w') as file:
    for v in model.trainable_variables:
        file.write(str(v.name) + '\n')
        file.write(str(v.shape) + '\n')
        file.write(str(v.numpy()) + '\n')

acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
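The parameter counts printed by model.summary() can be cross-checked by hand. The following is a rough sketch of that arithmetic for the Baseline model above, assuming 32x32x3 inputs and default BatchNormalization settings (scale and center enabled); the helper function names are hypothetical.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, filters):
    # One kernel_h x kernel_w x in_channels kernel and one bias per filter.
    return kernel_h * kernel_w * in_channels * filters + filters


def dense_params(in_units, out_units):
    # One weight per input-output pair plus one bias per output unit.
    return in_units * out_units + out_units


def same_pool_output(size, stride):
    # With padding='same' the output size is ceil(size / stride).
    return -(-size // stride)


conv = conv2d_params(5, 5, 3, 6)      # c1: 5x5 kernels, 3 input channels, 6 filters
bn_trainable = 2 * 6                  # b1: gamma and beta per channel
bn_non_trainable = 2 * 6              # b1: moving mean and moving variance
side = same_pool_output(32, 2)        # p1: 32 -> 16 in each spatial dimension
flat = side * side * 6                # flatten: 16 * 16 * 6 values
f1 = dense_params(flat, 128)          # f1: Dense(128)
f2 = dense_params(128, 10)            # f2: Dense(10)

total = conv + bn_trainable + bn_non_trainable + f1 + f2
print(conv, flat, f1, f2, total)      # prints: 456 1536 196736 1290 198506
```

Almost all of the parameters sit in the first Dense layer, because the single pooling step leaves 1536 features to connect to 128 units.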

  
