TensorFlow 2.0 —— GAN 实战代码

from  tensorflow import keras
import tensorflow as tf
from  tensorflow.keras import layers
import numpy as np
import os
import matplotlib.pyplot as plt

#   Low-level runtime configuration.
# To pin training to a specific GPU, uncomment these lines. They should be set
# before TensorFlow initializes the GPUs, so keep them above the device query:
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # use the 2nd GPU

# Allocate GPU memory on demand instead of reserving it all up front.
# The setting is applied to every visible GPU because TensorFlow requires it
# to be consistent across GPUs. On a machine without a GPU, the loop is a
# no-op and training falls back to the CPU instead of aborting.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
if not physical_devices:
    print('No GPU found; training will run on the CPU.')
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

#   拼接图片
def my_save_img(data, save_path, r_c=10):
    """Tile the first ``r_c * r_c`` images of *data* into one grid and save it.

    Args:
        data: Batch of images shaped ``[n, h, w, c]`` with values in
            ``[-1, 1]`` (such as the generator's tanh output). May be a NumPy
            array or an eager tensor.
        save_path: Output file path passed to ``plt.imsave``. Its directory
            must already exist.
        r_c: Number of rows and columns in the grid (default 10).

    Grid cells with no image (when ``n < r_c * r_c``) are left black.
    """
    images = np.asarray(data)[:r_c * r_c]
    # Cell size and channel count come from the data rather than being fixed.
    h, w, c = images.shape[1:4]
    new_img = np.zeros((r_c * h, r_c * w, c))
    for index, each_img in enumerate(images):
        row, col = divmod(index, r_c)
        # Map [-1, 1] -> [0, 1] for display.
        new_img[row * h:(row + 1) * h, col * w:(col + 1) * w, :] = (each_img + 1) / 2

    # plt.imsave rejects float RGB data outside [0, 1]. Clip in case the
    # input was not strictly within [-1, 1].
    plt.imsave(save_path, np.clip(new_img, 0.0, 1.0))

class Generator(keras.Model):
    """Map a batch of latent vectors to 64x64 RGB images in [-1, 1].

    Shape flow (all transposed convs use 'valid' padding):
        z [b, z_dim] -> Dense -> reshape [b, 3, 3, 512]
        -> Conv2DTranspose(k=3, s=3) -> [b, 9, 9, 256]
        -> Conv2DTranspose(k=5, s=2) -> [b, 21, 21, 128]
        -> Conv2DTranspose(k=4, s=3) -> [b, 64, 64, 3] -> tanh
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept as-is. TF tracks
        # sub-layers by attribute, so renaming them would change checkpoint keys.
        self.fc = layers.Dense(3 * 3 * 512)

        self.Tconv1 = layers.Conv2DTranspose(filters=256, kernel_size=3, strides=3, padding='valid')
        self.bn1 = layers.BatchNormalization()

        self.Tconv2 = layers.Conv2DTranspose(filters=128, kernel_size=5, strides=2, padding='valid')
        self.bn2 = layers.BatchNormalization()

        self.Tconv3 = layers.Conv2DTranspose(filters=3, kernel_size=4, strides=3, padding='valid')

    def call(self, inputs, training=None, mask=None):
        """Run the generator; ``training`` is forwarded to the batch-norm layers."""
        # Project z and lay it out as a 3x3x512 feature map.
        projected = tf.reshape(self.fc(inputs), [-1, 3, 3, 512])
        h = tf.nn.leaky_relu(projected)

        # Two upsampling stages: transposed conv -> batch norm -> leaky ReLU.
        for tconv, bn in ((self.Tconv1, self.bn1), (self.Tconv2, self.bn2)):
            h = tf.nn.leaky_relu(bn(tconv(h), training=training))

        # Final upsampling to 3 channels, squashed into [-1, 1].
        return tf.tanh(self.Tconv3(h))

class Discriminator(keras.Model):
    """Score each image in a batch with a single real/fake logit.

    Shape flow for 64x64 inputs (kernel 5, stride 3, 'valid' padding):
        [b, 64, 64, 3] -> [b, 20, 20, 64] -> [b, 6, 6, 128] -> [b, 1, 1, 256]
        -> Flatten -> Dense -> [b, 1]  (raw logits; no sigmoid applied)
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept as-is for weight tracking.
        self.conv1 = layers.Conv2D(filters=64, kernel_size=5, strides=3, padding='valid')

        self.conv2 = layers.Conv2D(filters=128, kernel_size=5, strides=3, padding='valid')
        self.bn2 = layers.BatchNormalization()

        self.conv3 = layers.Conv2D(filters=256, kernel_size=5, strides=3, padding='valid')
        self.bn3 = layers.BatchNormalization()

        self.flatten = layers.Flatten()
        self.fc = layers.Dense(1)

    def call(self, inputs, training=None, mask=None):
        """Return logits of shape [b, 1]; ``training`` is forwarded to batch norm."""
        # The first conv has no batch norm; the next two do.
        h = tf.nn.leaky_relu(self.conv1(inputs))
        for conv, bn in ((self.conv2, self.bn2), (self.conv3, self.bn3)):
            h = tf.nn.leaky_relu(bn(conv(h), training=training))

        # Flatten [b, h, w, c] -> [b, -1], then project to one logit per image.
        return self.fc(self.flatten(h))

def main():
    """Train the Generator/Discriminator pair on the images stored in ``img.npy``.

    The whole array is loaded and fed in shuffled, endlessly repeated batches.
    The generator ends in tanh, and ``my_save_img`` maps [-1, 1] to [0, 1], so
    the real images presumably need to be 64x64x3 and already scaled to
    [-1, 1] (TODO confirm). The max/min printout below is there to check this.

    Every 10 steps the mean D and G losses are printed. Every 50 steps a grid
    of generated samples is written to ``g_pic2/gan-<step>.png``.
    """
    #   Hyperparameters
    z_dim = 100
    # NOTE: each loop iteration is one batch (a training step), not a full
    # pass over the dataset, despite the name.
    epochs = 3000000
    batch_size = 1024
    learning_rate = 0.002
    is_training = True
    sample_dir = 'g_pic2'
    # my_save_img tiles a 10x10 grid by default, so extra samples would be unused.
    n_samples = 100

    img_data = np.load('img.npy')
    train_db = tf.data.Dataset.from_tensor_slices(img_data).shuffle(10000).batch(batch_size)
    sample = next(iter(train_db))
    print(sample.shape, tf.reduce_max(sample).numpy(),
          tf.reduce_min(sample).numpy())

    # Repeat indefinitely; the number of steps is bounded by `epochs` instead.
    train_db = train_db.repeat()
    db_iter = iter(train_db)

    d = Discriminator()
    g = Generator()

    # Separate optimizers for the two networks.
    g_optimizer = tf.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
    d_optimizer = tf.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)

    # plt.imsave does not create missing directories.
    os.makedirs(sample_dir, exist_ok=True)

    for epoch in range(epochs):
        batch_z = tf.random.uniform([batch_size, z_dim], minval=-1., maxval=1.)
        batch_x = next(db_iter)

        # --- Discriminator step: real images -> label 1, generated -> label 0.
        with tf.GradientTape() as tape:
            fake_image = g(batch_z, training=is_training)
            d_fake_logits = d(fake_image, training=is_training)
            d_real_logits = d(batch_x, training=is_training)

            d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                logits=d_real_logits, labels=tf.ones_like(d_real_logits)))
            # Generated images must be labelled 0 so D learns to reject them.
            d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                logits=d_fake_logits, labels=tf.zeros_like(d_fake_logits)))

            d_loss = d_loss_fake + d_loss_real
        grads = tape.gradient(d_loss, d.trainable_variables)
        d_optimizer.apply_gradients(zip(grads, d.trainable_variables))

        # --- Generator step: update G so D classifies its output as real.
        with tf.GradientTape() as tape:
            fake_image = g(batch_z, training=is_training)
            d_fake_logits = d(fake_image, training=is_training)
            g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                logits=d_fake_logits, labels=tf.ones_like(d_fake_logits)))
        grads = tape.gradient(g_loss, g.trainable_variables)
        g_optimizer.apply_gradients(zip(grads, g.trainable_variables))

        if epoch % 10 == 0:
            print(epoch, 'd-loss:', float(d_loss), 'g-loss:', float(g_loss))
            if epoch % 50 == 0:
                # Sample z from the same U(-1, 1) prior used during training.
                z = tf.random.uniform([n_samples, z_dim], minval=-1., maxval=1.)
                fake_image = g(z, training=False)
                img_path = os.path.join(sample_dir, 'gan-%d.png' % epoch)
                my_save_img(fake_image, img_path)


if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    main()

 

上一篇：GAN 的实例代码——以 double moon dataset 为例


下一篇：第六节 图片风格迁移和 GAN