TensorFlow学习记录(九)---人工神经网络经典框架

TensorFlow学习记录(九)---人工神经网络经典框架

 TensorFlow学习记录(九)---人工神经网络经典框架

这种格式也可以,但不清晰

TensorFlow学习记录(九)---人工神经网络经典框架

 TensorFlow学习记录(九)---人工神经网络经典框架

 

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
#加载数据
mnist = tf.keras.datasets.mnist
(train_x,train_y),(test_x,test_y) = mnist.load_data()
x_train ,x_test = tf.cast(train_x/255.0,tf.float32),tf.cast(test_x/255.0,tf.float32)
y_train,y_test = tf.cast (train_y,tf.int16),tf.cast(test_y,tf.int16)
#建立模型
model = tf.keras.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28,28)))
model.add(tf.keras.layers.Dense(128,activation='relu'))
model.add(tf.keras.layers.Dense(10,activation='softmax'))
model.summary()
#配置模型的训练方法
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['sparse_categorical_accuracy'])
#训练模型
model.fit(x_train,y_train,batch_size=64,epochs=5,validation_split=0.2)
#评估模型
model.evaluate(x_test,y_test,verbose=2  )
#使用模型
plt.axis('off')
plt.imshow(test_x[0],cmap='gray')
plt.show()
print(y_test[0])
model.predict([[x_test[0]]])
print(np.argmax(model.predict([[x_test[0]]])))

上一篇:tensorflow鸢尾花分类


下一篇:TensorFlow学习记录(四)---TensorFlow自动求导机制