在学习分类的时候,mnist
数据集分类尝试时,因为训练维度的原因,教程中都是
model.predict([[X_train[0]]])
这里说明是:因为训练维度为 3 维,所以需要添加 2 层 [ ]
, 可是在我添加之后产生了如下问题
问题 :
问题是在访问数据集中单张图片时产生错误
下面为错误代码 :
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
gpus = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(gpus[0], True)
mnist = tf.keras.datasets.mnist
(train_x, train_y), (test_x, test_y) = mnist.load_data("C:/Users/你的用户名/.keras/datasets/mnist.npz") # 导入数据集
# X_train = train_X.reshape((60000, 28*28))
# X_test = test_X.reshape((10000, 28*28)) # 变换输入维度,这里使用 tf.keras,layers.Flatten()
X_train, X_test = tf.cast(train_x/255.0, tf.float32), tf.cast(test_x/255.0, tf.float32)
y_train, y_test = tf.cast(train_y/255.0, tf.float32), tf.cast(test_y/255.0, tf.float32) # 为加快迭代速度,属性归一化,且转换为tensor张量
# 建立模型
model = tf.keras.Sequential() # 选择模型
model.add(tf.keras.layers.Flatten(input_shape = (28, 28))) # 说明输入层形状,转换形状,拉直变成一维数组,line 12, 13
model.add(tf.keras.layers.Dense(128, activation = 'relu')) # 添加 Hidden 层,128 神经元, relu 激活函数
model.add(tf.keras.layers.Dense(10, activation = 'softmax')) # 添加 Output 层, 10 个类别神经元, softmax 激活函数,多分类
# model.summary() # 查看神经网络信息
# 配置模型训练方法
model.compile(optimizer = 'adam',
loss = 'sparse_categorical_crossentropy',
metrics = 'sparse_categorical_accuracy') # 选择优化梯度算法 adam, 稀疏交叉熵损失函数, 稀疏分类准确率函数
# 训练模型
model.fit(X_train, y_train, batch_size = 64, epochs = 5, validation_split = 0.2) # 划分 20% 作为测试集 64数据每批次,一共训练 5 轮
# 评估模型
model.evaluate(X_test, y_test, verbose = 2) # 使用本身测试集评估
# 使用模型
model.predict([[X_test[0]]]) # predict 中,参数维数和训练集一致,所以给出归一化后属性值,训练集属性为 3 维数组,需要 2 层方括号
np.argmax(model.predict([[X_test[0]]])) # 使用 argmax 得到最大概率索引值
为此我猜测可能是维度问题,因为在model.predict(X_test[0:4])
中,并未产生问题,于是修改了部分代码,如下:
# 使用模型
demo = tf.constant((test_x[0]/255.0).reshape(1, 28, 28), dtype=tf.float32) # 变换一下维度先归一化后 reshape 成 3 维再转换Tensor
model.predict(demo) # predict 中,参数维数和训练集一致,所以给出归一化后属性值,训练集属性为 3 维数组,需要 2 层方括号
np.argmax(model.predict(demo))
先将测试集第一张图片归一化后 reshape 为 3 维形状,再转换成 Tensor 张量,发现可以使用,成功输出了 7
但是在成功输出的同时,伴随有以下警告 :
版本信息为 :
暂时先问以下这个问题吧,希望看到的大佬可以评论留下解决方法,小白不胜感激!!