通过 Keras 实现 GAN ,其主要过程如下:
正如上图所示,通过让 Generator 和 Discriminator 交替训练,使生成的数据不断逼近真实数据的分布。这里 Generator 的输入为随机噪声向量。
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import numpy as np
from tqdm import tqdm
import matplotlib
matplotlib.use("agg")
import matplotlib.pyplot as plt
from keras.layers import Input
from keras.models import Model, Sequential
from keras.layers.core import Reshape, Dense, Dropout, Flatten
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Convolution2D, UpSampling2D
from keras.layers.normalization import BatchNormalization
from keras.datasets import mnist
from keras.optimizers import Adam
from keras import backend as K
from keras import initializers
# Select the 'th' image dimension ordering for the Keras backend.
K.set_image_dim_ordering('th')

# Seed NumPy's random generator so that repeated runs give the same output.
# Remove this call if you want different results on every run.
np.random.seed(1000)

# Length of the random noise vector fed to the generator. A value of 10 gives
# slightly better results; 100 is kept for consistency with other GAN
# implementations.
randomDim = 100

# Load MNIST, rescale the training pixels so that 0 maps to -1 and 255 maps
# to 1 (the range of the generator's tanh output), then flatten every image
# into a 784-element row.
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.astype(np.float32)
X_train = (X_train - 127.5) / 127.5
X_train = X_train.reshape(60000, 784)

# Adam optimizer instance passed to every compile() call below.
adam = Adam(lr=2e-4, beta_1=0.5)
# Generator: turns a randomDim-long noise vector into 784 values in [-1, 1]
# through three LeakyReLU-activated dense layers and a tanh output layer.
generator = Sequential([
    Dense(
        256,
        input_dim=randomDim,
        kernel_initializer=initializers.RandomNormal(stddev=0.02),
    ),
    LeakyReLU(0.2),
    Dense(512),
    LeakyReLU(0.2),
    Dense(1024),
    LeakyReLU(0.2),
    Dense(784, activation='tanh'),
])
generator.compile(optimizer=adam, loss='binary_crossentropy')
# Discriminator: takes a flattened 784-value image and produces a single
# sigmoid output. Each hidden dense layer is followed by LeakyReLU and
# 30% dropout.
discriminator = Sequential([
    Dense(
        1024,
        input_dim=784,
        kernel_initializer=initializers.RandomNormal(stddev=0.02),
    ),
    LeakyReLU(0.2),
    Dropout(0.3),
    Dense(512),
    LeakyReLU(0.2),
    Dropout(0.3),
    Dense(256),
    LeakyReLU(0.2),
    Dropout(0.3),
    Dense(1, activation='sigmoid'),
])
discriminator.compile(optimizer=adam, loss='binary_crossentropy')
# Stacked model: noise -> generator -> discriminator.
# The discriminator is flagged as non-trainable before the stacked model is
# built and compiled.
discriminator.trainable = False
ganInput = Input(shape=(randomDim,))
x = generator(ganInput)
ganOutput = discriminator(x)
gan = Model(inputs=ganInput, outputs=ganOutput)
gan.compile(optimizer=adam, loss='binary_crossentropy')

# Loss histories: train() appends to these lists and plotLoss() draws them.
dLosses, gLosses = [], []
def plotLoss(epoch):
    """Plot the recorded discriminator and generator losses and save the figure.

    The curves are drawn from the module-level ``dLosses`` and ``gLosses``
    lists and written to ``images/gan_loss_epoch_<epoch>.png``.

    Args:
        epoch: Integer used in the output file name.
    """
    # savefig() cannot write into a directory that does not exist yet.
    os.makedirs('images', exist_ok=True)
    fig = plt.figure(figsize=(10, 8))
    try:
        plt.plot(dLosses, label='Discriminitive loss')
        plt.plot(gLosses, label='Generative loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.savefig('images/gan_loss_epoch_%d.png' % epoch)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)
def plotGeneratedImages(epoch, examples=100, dim=(10, 10), figsize=(10, 10)):
    """Save a wall of MNIST-style images sampled from the generator.

    Fresh noise is drawn, passed through ``generator`` and the result is
    written to ``images/gan_generated_image_epoch_<epoch>.png``.

    Args:
        epoch: Integer used in the output file name.
        examples: Number of images to generate.
        dim: (rows, columns) of the subplot grid.
        figsize: Figure size handed to ``plt.figure``.
    """
    # savefig() cannot write into a directory that does not exist yet.
    os.makedirs('images', exist_ok=True)

    noise = np.random.normal(0, 1, size=[examples, randomDim])
    # Each generated 784-value row becomes one 28x28 image.
    generatedImages = generator.predict(noise).reshape(examples, 28, 28)

    fig = plt.figure(figsize=figsize)
    try:
        for position, image in enumerate(generatedImages, start=1):
            plt.subplot(dim[0], dim[1], position)
            plt.imshow(image, interpolation='nearest', cmap='gray_r')
            plt.axis('off')
        plt.tight_layout()
        plt.savefig('images/gan_generated_image_epoch_%d.png' % epoch)
    finally:
        # This function runs repeatedly during training; without an explicit
        # close every call would leave another open figure behind.
        plt.close(fig)
def saveModels(epoch):
    """Save the generator and discriminator networks (and weights) for later use.

    The files are written to ``models/gan_generator_epoch_<epoch>.h5`` and
    ``models/gan_discriminator_epoch_<epoch>.h5``.

    Args:
        epoch: Integer used in the output file names.
    """
    # Make sure the target directory exists before the models are written.
    os.makedirs('models', exist_ok=True)
    generator.save('models/gan_generator_epoch_%d.h5' % epoch)
    discriminator.save('models/gan_discriminator_epoch_%d.h5' % epoch)
def train(epochs=1, batchSize=128):
    """Alternately train the discriminator and the generator.

    Every batch first updates the discriminator on a mix of real and
    generated images, then updates the generator through the stacked ``gan``
    model. The losses of the last batch of each epoch are appended to
    ``dLosses`` / ``gLosses``. Sample images and model files are written after
    epoch 1 and after every 20th epoch, and the loss curves are plotted once
    training has finished.

    Args:
        epochs: Number of passes to run; must be at least 1.
        batchSize: Number of real images (and of generated images) per batch;
            must be at least 1 and no larger than the training set.

    Raises:
        ValueError: If ``epochs`` or ``batchSize`` is smaller than 1, or if
            ``batchSize`` exceeds the number of training images.
    """
    # Validate up front: with zero epochs or zero batches the loop variables
    # below would never be bound and the function would die with an opaque
    # UnboundLocalError.
    if epochs < 1:
        raise ValueError('epochs must be at least 1, got %r' % (epochs,))
    if batchSize < 1:
        raise ValueError('batchSize must be at least 1, got %r' % (batchSize,))

    batchCount = X_train.shape[0] // batchSize
    if batchCount < 1:
        raise ValueError(
            'batchSize (%r) is larger than the training set (%d samples)'
            % (batchSize, X_train.shape[0])
        )

    print('Epochs:', epochs)
    print('Batch size:', batchSize)
    print('Batches per epoch:', batchCount)

    for e in range(1, epochs+1):
        print('-'*15, 'Epoch %d' % e, '-'*15)
        for _ in tqdm(range(batchCount)):
            # Get a random set of input noise and real images.
            noise = np.random.normal(0, 1, size=[batchSize, randomDim])
            imageBatch = X_train[np.random.randint(0, X_train.shape[0], size=batchSize)]

            # Generate fake MNIST images with the generator.
            generatedImages = generator.predict(noise)

            # Stack the real images (first half) on top of the fake ones.
            X = np.concatenate([imageBatch, generatedImages])

            # Labels: start with everything marked as fake (0) ...
            yDis = np.zeros(2*batchSize)
            # ... then apply one-sided label smoothing: the real half gets
            # 0.9 instead of 1.
            yDis[:batchSize] = 0.9

            # Train the discriminator. Only the discriminator's weights are
            # updated by this call; the generator was used for prediction
            # only.
            # NOTE(review): both models are already compiled at this point;
            # whether toggling `trainable` here has any effect depends on the
            # Keras version - confirm.
            discriminator.trainable = True
            dloss = discriminator.train_on_batch(X, yDis)

            # Train the generator through the stacked model: draw fresh noise
            # and label the generated samples as real (1), so the generator
            # is pushed towards output the discriminator accepts.
            noise = np.random.normal(0, 1, size=[batchSize, randomDim])
            yGen = np.ones(batchSize)
            discriminator.trainable = False
            gloss = gan.train_on_batch(noise, yGen)

        # Store loss of most recent batch from this epoch.
        dLosses.append(dloss)
        gLosses.append(gloss)

        if e == 1 or e % 20 == 0:
            plotGeneratedImages(e)
            saveModels(e)

    # Plot the per-epoch losses once training is done; the file name carries
    # the final epoch number.
    plotLoss(epochs)
# Script entry point: train for 200 epochs with batches of 128 images.
if __name__ == '__main__':
    train(epochs=200, batchSize=128)
上述代码通过 Sequential 创建 Generator 和 Discriminator 网络,然后通过 Model 将二者组合在一起,从而构成了完整的 GAN 网络。