import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image
from tensorflow import keras
from tensorflow.keras import Sequential, layers

# Low-level runtime configuration: require at least one GPU and enable
# memory growth on the first one.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
assert len(physical_devices) > 0, "Not enough GPU hardware devices available"
tf.config.experimental.set_memory_growth(physical_devices[0], True)
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # use the second GPU

# Hyperparameters
h_dim = 20          # size of the latent code produced by the encoder
batchsz = 512       # mini-batch size for both training and test pipelines
learn_rate = 1e-3   # optimizer learning rate

# Fashion-MNIST images only (labels are discarded), scaled to [0, 1].
(x_train, _), (x_test, _) = keras.datasets.fashion_mnist.load_data()
x_train, x_test = x_train.astype(np.float32) / 255., x_test.astype(np.float32) / 255.
print('x_train.shape:', x_train.shape)

train_db = tf.data.Dataset.from_tensor_slices(x_train).shuffle(batchsz * 5).batch(batchsz)
test_db = tf.data.Dataset.from_tensor_slices(x_test).batch(batchsz)


def my_save_img(data, name):
    """Tile up to 100 images of 28x28 pixels into a 10x10 grid and save it.

    Args:
        data: Indexable collection of 28x28 images; only the first 100
            entries are used. Unused grid cells stay zero.
        name: File stem; the grid is written to
            ``./img_dir/AE_img/<name>.jpg``.
    """
    save_img_path = './img_dir/AE_img/{}.jpg'.format(name)
    # plt.imsave does not create missing parent directories, so make sure
    # the output folder exists before writing.
    os.makedirs(os.path.dirname(save_img_path), exist_ok=True)

    new_img = np.zeros((280, 280))
    for index, each_img in enumerate(data[:100]):
        # Row-major placement: 10 tiles per row, each tile 28 pixels wide.
        row, col = divmod(index, 10)
        row_start = row * 28
        col_start = col * 28
        new_img[row_start:row_start + 28, col_start:col_start + 28] = each_img
    plt.imsave(save_img_path, new_img)


# Preview a few training samples:
# for i in range(16):
#     plt.subplot(4, 4, i + 1)
#     plt.imshow(np.reshape(x_train[i], (28, 28, 1)))
# plt.show()


class AE(keras.Model):
    """Fully connected autoencoder: 784 -> h_dim -> 784.

    The decoder's last layer has no activation, so ``call`` returns raw
    logits; callers apply ``tf.sigmoid`` (or use ``from_logits=True``).
    """

    def __init__(self):
        super(AE, self).__init__()
        # Encoder: compresses the flattened image down to h_dim values.
        self.encoder = Sequential([
            layers.Dense(256, activation=tf.nn.relu),
            layers.Dense(128, activation=tf.nn.relu),
            layers.Dense(h_dim),
        ])
        # Decoder: mirrors the encoder and emits 28 * 28 logits.
        self.decoder = Sequential([
            layers.Dense(128, activation=tf.nn.relu),
            layers.Dense(256, activation=tf.nn.relu),
            layers.Dense(28 * 28),
        ])

    def call(self, inputs, training=None, mask=None):
        """Encode then decode ``inputs`` and return the reconstruction logits."""
        # [b, 784] => [b, h_dim]
        x = self.encoder(inputs)
        # [b, h_dim] => [b, 784]
        x = self.decoder(x)
        return x


my_model = AE()
my_model.build(input_shape=(None, 784))
my_model.summary()

# Use `learning_rate`: the `lr` keyword is a deprecated alias that newer
# Keras releases no longer accept.
opt = tf.optimizers.Adam(learning_rate=learn_rate)

# test_db is not shuffled, so its first batch is the same every epoch;
# fetch it once instead of rebuilding the iterator on each pass.
eval_images = next(iter(test_db))
eval_flat = tf.reshape(eval_images, [-1, 784])

for epoch in range(50):
    for step, x in enumerate(train_db):
        # Flatten: [b, 28, 28] => [b, 784]
        x = tf.reshape(x, [-1, 784])
        with tf.GradientTape() as tape:
            out = my_model(x)
            # The model outputs logits, hence from_logits=True.
            my_loss = tf.losses.binary_crossentropy(x, out, from_logits=True)
            # my_loss = tf.losses.mean_squared_error(x, out)
            my_loss = tf.reduce_mean(my_loss)
        grads = tape.gradient(my_loss, my_model.trainable_variables)
        opt.apply_gradients(zip(grads, my_model.trainable_variables))

        if step % 100 == 0:
            print(epoch, step, float(my_loss))

    # Evaluation: save the reference images and their reconstructions.
    # NOTE(review): assumed to run once per epoch (file names are keyed by
    # epoch) — confirm against the original indentation.
    my_save_img(eval_images, '{}_label'.format(epoch))
    logits = my_model(eval_flat)
    x_hat = tf.sigmoid(logits)  # matches the binary cross-entropy loss
    # x_hat = logits  # use this instead when training with the MSE loss
    x_hat = tf.reshape(x_hat, [-1, 28, 28])
    my_save_img(x_hat, '{}_pre'.format(epoch))