Building a GAN neural network with tensorflow.keras (directly runnable)
Preface
Keras is one of TensorFlow's high-level APIs; its code is concise and very readable. This article uses tensorflow.keras to implement a GAN. The underlying theory is not covered in depth here; the post is simply a worked example for discussion.
I. Overall steps for building a GAN with tf.keras
1. First convert all image data into TensorFlow's tfrecords format, generated with the creat_tfrecords.py script (this script was originally written to generate labels for image classification; for a GAN the labels do not need to be stored).
2. Build the dataset from the generated tfrecords file using tf.data.TFRecordDataset. This article also shows a second way to read tfrecords data, but both roads lead to the same place and the methods are very similar.
3. Build the generator network.
4. Build the discriminator network and combine the two into the GAN (the discriminator must be set to non-trainable before the GAN is compiled).
5. Build the training loop that alternately trains the generator and the discriminator.
6. Save the network as gan.model.
II. Usage
1. Creating the tfrecords dataset
creat_tfrecords.py
By default the tfrecords file is written to filename_train = "./data/train.tfrecords".
Run in a terminal: python creat_tfrecords.py --data [path to your dataset]
This produces train.tfrecords; validation and test sets can be added in the same way.
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os
from PIL import Image
import random

objects = ['cat', 'dog']  # 'cat' -> 0, 'dog' -> 1
filename_train = "./data/train.tfrecords"
writer_train = tf.python_io.TFRecordWriter(filename_train)

tf.app.flags.DEFINE_string(
    'data', 'None', 'where the datas?.')
FLAGS = tf.app.flags.FLAGS
if FLAGS.data == 'None':  # exit if no dataset path was given on the command line
    os._exit(0)

dim = (224, 224)
object_path = FLAGS.data
total = os.listdir(object_path)

for index in total:
    img_path = os.path.join(object_path, index)
    img = Image.open(img_path)
    img = img.resize(dim)
    img_raw = img.tobytes()
    for i in range(len(objects)):
        if objects[i] in index:
            value = i
        else:
            continue
    example = tf.train.Example(features=tf.train.Features(feature={
        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[value])),
        'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
    }))
    print([index, value])
    writer_train.write(example.SerializeToString())  # serialize to a string and write it out
writer_train.close()
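The script above relies on TF 1.x modules (tf.python_io.TFRecordWriter, tf.app.flags) that were removed in TensorFlow 2.x. The following is only a hedged sketch of the same writer for TF 2.x, using tf.io.TFRecordWriter and argparse; the names parser/args are introduced here and, since a GAN does not need labels, the label is simply hardcoded to 0.

# Hedged TF 2.x sketch of the same tfrecords writer (assumes TensorFlow >= 2.0).
import argparse
import os
import tensorflow as tf
from PIL import Image

parser = argparse.ArgumentParser()
parser.add_argument('--data', required=True, help='directory containing the images')
args = parser.parse_args()

with tf.io.TFRecordWriter('./data/train.tfrecords') as writer:
    for name in os.listdir(args.data):
        # force 3 channels and the 224x224 size expected by the networks below
        img = Image.open(os.path.join(args.data, name)).convert('RGB').resize((224, 224))
        example = tf.train.Example(features=tf.train.Features(feature={
            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[0])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img.tobytes()])),
        }))
        writer.write(example.SerializeToString())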
2. Reading the data
Built with tf.data.TFRecordDataset
The code is below (load_image serves as the map function that decodes the dataset records); it is called in the main function as:
train_datas,iter = dataset_tfrecords(tfrecords_path,use_keras_fit=False)
def load_image(serialized_example):
    features = {
        'label': tf.io.FixedLenFeature([], tf.int64),
        'img_raw': tf.io.FixedLenFeature([], tf.string)}
    parsed_example = tf.io.parse_example(serialized_example, features)
    image = tf.decode_raw(parsed_example['img_raw'], tf.uint8)
    image = tf.reshape(image, [-1, 224, 224, 3])
    image = tf.cast(image, tf.float32)*(1./255)
    label = tf.cast(parsed_example['label'], tf.int32)
    label = tf.reshape(label, [-1, 1])
    return image, label

def dataset_tfrecords(tfrecords_path, use_keras_fit=True):
    #whether the dataset will be consumed by tf.keras fit()
    if use_keras_fit:
        epochs_data = 1
    else:
        epochs_data = epochs
    #the list may hold several files, e.g. [tfrecords_name1, tfrecords_name2, ...],
    #which can be collected with os.listdir(tfrecords_path)
    dataset = tf.data.TFRecordDataset([tfrecords_path])
    dataset = dataset\
        .repeat(epochs_data)\
        .batch(batch_size)\
        .map(load_image, num_parallel_calls=2)\
        .shuffle(1000)
    iter = dataset.make_one_shot_iterator()  # or make_initializable_iterator
    train_datas = iter.get_next()  # access the values as train_datas[0], train_datas[1]
    return train_datas, iter
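For reference, under TensorFlow 2.x (or with tf.enable_eager_execution()) the one-shot iterator and the Session are unnecessary, because a dataset is a plain Python iterable. A hedged sketch of consuming the same pipeline eagerly, assuming load_image is adjusted to call tf.io.decode_raw instead of tf.decode_raw:

# Hedged eager-mode sketch: no make_one_shot_iterator()/sess.run() needed.
dataset = tf.data.TFRecordDataset([tfrecords_path])\
    .repeat(epochs)\
    .batch(batch_size)\
    .map(load_image, num_parallel_calls=2)\
    .shuffle(1000)
for images, labels in dataset:
    pass  # images is a float32 batch in [0, 1]; labels has shape [batch, 1]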
3. Building the GAN network
a. Build the generator network
generator = keras.models.Sequential([
    #fully connected nets
    keras.layers.Dense(256, activation='selu', input_shape=[coding_size]),
    keras.layers.Dense(64, activation='selu'),
    keras.layers.Dense(256, activation='selu'),
    keras.layers.Dense(1024, activation='selu'),
    keras.layers.Dense(7*7*64, activation='selu'),
    keras.layers.Reshape([7, 7, 64]),
    #7*7*64
    #transposed convolutions
    keras.layers.Conv2DTranspose(64, kernel_size=3, strides=2, padding='same', activation='selu'),
    #14*14*64
    keras.layers.Conv2DTranspose(64, kernel_size=3, strides=2, padding='same', activation='selu'),
    #28*28*64
    keras.layers.Conv2DTranspose(32, kernel_size=3, strides=2, padding='same', activation='selu'),
    #56*56*32
    keras.layers.Conv2DTranspose(16, kernel_size=3, strides=2, padding='same', activation='selu'),
    #112*112*16
    keras.layers.Conv2DTranspose(3, kernel_size=3, strides=2, padding='same', activation='tanh'),  # tanh instead of sigmoid
    #224*224*3
    keras.layers.Reshape([224, 224, 3])
])
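Because padding='same' makes every strided Conv2DTranspose simply multiply the spatial size by its stride (output = input × strides), the 7×7 feature map grows to 224×224 after the five upsampling layers. A quick way to confirm the shapes once the model is defined:

generator.summary()  # the final Reshape layer should report an output shape of (None, 224, 224, 3)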
b. Build the discriminator network
discriminator = keras.models.Sequential([
    keras.layers.Conv2D(128, kernel_size=3, padding='same', strides=2, activation='selu', input_shape=[224, 224, 3]),
    keras.layers.MaxPool2D(pool_size=2),
    #56*56*128
    keras.layers.Conv2D(64, kernel_size=3, padding='same', strides=2, activation='selu'),
    keras.layers.MaxPool2D(pool_size=2),
    #14*14*64
    keras.layers.Conv2D(32, kernel_size=3, padding='same', strides=2, activation='selu'),
    #7*7*32
    keras.layers.Flatten(),
    #dropout 0.4
    keras.layers.Dropout(0.4),
    keras.layers.Dense(512, activation='selu'),
    keras.layers.Dropout(0.4),
    keras.layers.Dense(64, activation='selu'),
    keras.layers.Dropout(0.4),
    #the last net
    keras.layers.Dense(1, activation='sigmoid')
])
c. Combine the generator and discriminator into the GAN
gan = keras.models.Sequential([generator,discriminator])
4. Compile (set up the loss and optimizer). Keras captures each layer's trainable flag at compile time: the discriminator is compiled first (still trainable), then frozen before the gan model is compiled, so discriminator.train_on_batch keeps updating it while gan.train_on_batch only updates the generator.
#compile the net
discriminator.compile(loss="binary_crossentropy",optimizer='rmsprop')# metrics=['accuracy'])
discriminator.trainable=False
gan.compile(loss="binary_crossentropy",optimizer='rmsprop')# metrics=['accuracy'])
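A hedged, hypothetical sanity check of that behaviour (not part of the original script; it assumes noise and y2 are built exactly as in the training loop below):

# The first discriminator weight matrix should be unchanged by one gan training step,
# because trainable was already False when gan was compiled.
w_before = discriminator.get_weights()[0].copy()
gan.train_on_batch(noise, y2)
w_after = discriminator.get_weights()[0]
print(np.allclose(w_before, w_after))  # expected: True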
5. Training the network (the training loop)
Get the dataset:
train_datas,iter = dataset_tfrecords(tfrecords_path,use_keras_fit=False)
The loop body (cv2 is used inside it to preview what the generator produces):
generator, discriminator = gan.layers
sess = tf.Session()
for step in range(num_steps):
    #get the time
    start_time = time.time()
    #phase 1 - training the discriminator
    noise = np.random.normal(size=batch_size*coding_size).reshape([batch_size, coding_size])
    noise = np.cast[np.float32](noise)
    generated_images = generator.predict(noise)
    train_datas_ = sess.run(train_datas)
    x_fake_and_real = np.concatenate([generated_images, train_datas_[0]], axis=0)
    #never call tf.concat (or any other tf op) inside the loop body:
    #it keeps growing the graph, exhausts memory and makes training slower and slower
    y1 = np.array([[0.]]*batch_size + [[1.]]*batch_size)
    discriminator.trainable = True
    dis_loss = discriminator.train_on_batch(x_fake_and_real, y1)
    #keras' train_on_batch is a good fit for GAN training
    #phase 2 - training the generator
    noise = np.random.normal(size=batch_size*coding_size).reshape([batch_size, coding_size])
    noise = np.cast[np.float32](noise)
    y2 = np.array([[1.]]*batch_size)
    discriminator.trainable = False
    ad_loss = gan.train_on_batch(noise, y2)
    duration = time.time() - start_time
    if step % 5 == 0:
        #gan.save_weights('gan.h5')
        print("The step is %d,discriminator loss:%.3f,adversarial loss:%.3f" % (step, dis_loss, ad_loss), end=' ')
        print('%.2f s/step' % (duration))
    if step % 30 == 0 and step != 0:
        noise = np.random.normal(size=[1, coding_size])
        noise = np.cast[np.float32](noise)
        fake_image = generator.predict(noise, steps=1)
        #recover the image:
        #1. multiply by 255 and cast to uint8
        #2. alternatively keep the float32 values in [0, 1], which can also be displayed directly
        arr_img = np.array([fake_image], np.float32).reshape([224, 224, 3])*255
        arr_img = np.cast[np.uint8](arr_img)
        #the tfrecords were written with PIL.Image (RGB), so convert to BGR before showing with cv2
        arr_img = cv2.cvtColor(arr_img, cv2.COLOR_RGB2BGR)
        cv2.imshow('fake image', arr_img)
        cv2.waitKey(1500)  # show the fake image for 1.5 s
        cv2.destroyAllWindows()
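If the script runs on a headless server where cv2.imshow cannot open a window, a hedged alternative is to write the preview image to disk instead (the filename pattern below is my own):

cv2.imwrite('fake_image_step_%d.png' % step, arr_img)  # replaces the imshow/waitKey preview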
6. Saving the network
#save the model
model_version = '0001'
model_name = 'gans'
model_path = os.path.join(model_name, model_version)
tf.saved_model.save(gan, model_path)
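An alternative, assuming h5py is installed, is to save the Keras model directly; keras.models.load_model should then restore both the architecture and the weights. This is only a hedged sketch, not what the original script does:

gan.save('gan.h5')                             # architecture + weights in a single HDF5 file
restored = keras.models.load_model('gan.h5')   # rebuilds the generator/discriminator stack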
7. The complete gans.py (runnable)
# -*- coding: utf-8 -*-
'''
@author:zyl
author is zouyuelin
a Master of Tianjin University(TJU)
'''
import tensorflow as tf
from tensorflow import keras
#tf.enable_eager_execution()
import numpy as np
from PIL import Image
import os
import cv2
import time
batch_size = 32
epochs = 120
num_steps = 2000
coding_size = 30
tfrecords_path = 'data/train.tfrecords'
#--------------------------------------datasetTfrecord----------------
def load_image(serialized_example):
    features = {
        'label': tf.io.FixedLenFeature([], tf.int64),
        'img_raw': tf.io.FixedLenFeature([], tf.string)}
    parsed_example = tf.io.parse_example(serialized_example, features)
    image = tf.decode_raw(parsed_example['img_raw'], tf.uint8)
    image = tf.reshape(image, [-1, 224, 224, 3])
    image = tf.cast(image, tf.float32)*(1./255)
    label = tf.cast(parsed_example['label'], tf.int32)
    label = tf.reshape(label, [-1, 1])
    return image, label

def dataset_tfrecords(tfrecords_path, use_keras_fit=True):
    #whether the dataset will be consumed by tf.keras fit()
    if use_keras_fit:
        epochs_data = 1
    else:
        epochs_data = epochs
    #the list may hold several files, e.g. [tfrecords_name1, tfrecords_name2, ...],
    #which can be collected with os.listdir(tfrecords_path)
    dataset = tf.data.TFRecordDataset([tfrecords_path])
    dataset = dataset\
        .repeat(epochs_data)\
        .batch(batch_size)\
        .map(load_image, num_parallel_calls=2)\
        .shuffle(1000)
    iter = dataset.make_one_shot_iterator()  # or make_initializable_iterator
    train_datas = iter.get_next()  # access the values as train_datas[0], train_datas[1]
    return train_datas, iter
#------------------------------------tf.TFRecordReader-----------------
def read_and_decode(tfrecords_path):
    #build a filename queue from the file list
    filename_queue = tf.train.string_input_producer([tfrecords_path], shuffle=True)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example, features={
        'label': tf.FixedLenFeature([], tf.int64),
        'img_raw': tf.FixedLenFeature([], tf.string)})
    image = tf.decode_raw(features['img_raw'], tf.uint8)
    image = tf.reshape(image, [224, 224, 3])  # reshape to 224*224*3
    image = tf.cast(image, tf.float32)*(1./255)  # scale the image tensor to [0, 1]
    label = tf.cast(features['label'], tf.int32)
    img_batch, label_batch = tf.train.shuffle_batch([image, label],
                                                    batch_size=batch_size,
                                                    num_threads=4,
                                                    capacity=640,
                                                    min_after_dequeue=5)
    return [img_batch, label_batch]
#autoencoder-style decoder (kept for reference; not called anywhere in this script)
def autoencode():
    encoder = keras.models.Sequential([
        keras.layers.Conv2D(32, kernel_size=3, padding='same', strides=2, activation='selu', input_shape=[224, 224, 3]),
        #112*112*32
        keras.layers.MaxPool2D(pool_size=2),
        #56*56*32
        keras.layers.Conv2D(64, kernel_size=3, padding='same', strides=2, activation='selu'),
        #28*28*64
        keras.layers.MaxPool2D(pool_size=2),
        #14*14*64
        keras.layers.Conv2D(128, kernel_size=3, padding='same', strides=2, activation='selu'),
        #7*7*128
        #transposed convolutions
        keras.layers.Conv2DTranspose(128, kernel_size=3, strides=2, padding='same', activation='selu'),
        #14*14*128
        keras.layers.Conv2DTranspose(64, kernel_size=3, strides=2, padding='same', activation='selu'),
        #28*28*64
        keras.layers.Conv2DTranspose(32, kernel_size=3, strides=2, padding='same', activation='selu'),
        #56*56*32
        keras.layers.Conv2DTranspose(16, kernel_size=3, strides=2, padding='same', activation='selu'),
        #112*112*16
        keras.layers.Conv2DTranspose(3, kernel_size=3, strides=2, padding='same', activation='tanh'),  # tanh instead of sigmoid
        #224*224*3
        keras.layers.Reshape([224, 224, 3])
    ])
    return encoder
def training_keras():
    '''
    Output size of a convolution / pooling layer:
        output_size = (input_size - kernel_size + 2*padding)/strides + 1
    Output size of a Keras transposed convolution (out_padding is usually not needed):
    1. if padding = 'valid':
        output_size = (input_size - 1)*strides + kernel_size
    2. if padding = 'same':
        output_size = input_size * strides
    '''
    generator = keras.models.Sequential([
        #fully connected nets
        keras.layers.Dense(256, activation='selu', input_shape=[coding_size]),
        keras.layers.Dense(64, activation='selu'),
        keras.layers.Dense(256, activation='selu'),
        keras.layers.Dense(1024, activation='selu'),
        keras.layers.Dense(7*7*64, activation='selu'),
        keras.layers.Reshape([7, 7, 64]),
        #7*7*64
        #transposed convolutions
        keras.layers.Conv2DTranspose(64, kernel_size=3, strides=2, padding='same', activation='selu'),
        #14*14*64
        keras.layers.Conv2DTranspose(64, kernel_size=3, strides=2, padding='same', activation='selu'),
        #28*28*64
        keras.layers.Conv2DTranspose(32, kernel_size=3, strides=2, padding='same', activation='selu'),
        #56*56*32
        keras.layers.Conv2DTranspose(16, kernel_size=3, strides=2, padding='same', activation='selu'),
        #112*112*16
        keras.layers.Conv2DTranspose(3, kernel_size=3, strides=2, padding='same', activation='tanh'),  # tanh instead of sigmoid
        #224*224*3
        keras.layers.Reshape([224, 224, 3])
    ])
    discriminator = keras.models.Sequential([
        keras.layers.Conv2D(128, kernel_size=3, padding='same', strides=2, activation='selu', input_shape=[224, 224, 3]),
        keras.layers.MaxPool2D(pool_size=2),
        #56*56*128
        keras.layers.Conv2D(64, kernel_size=3, padding='same', strides=2, activation='selu'),
        keras.layers.MaxPool2D(pool_size=2),
        #14*14*64
        keras.layers.Conv2D(32, kernel_size=3, padding='same', strides=2, activation='selu'),
        #7*7*32
        keras.layers.Flatten(),
        #dropout 0.4
        keras.layers.Dropout(0.4),
        keras.layers.Dense(512, activation='selu'),
        keras.layers.Dropout(0.4),
        keras.layers.Dense(64, activation='selu'),
        keras.layers.Dropout(0.4),
        #the last net
        keras.layers.Dense(1, activation='sigmoid')
    ])
    #gans network
    gan = keras.models.Sequential([generator, discriminator])
    #compile the net
    discriminator.compile(loss="binary_crossentropy", optimizer='rmsprop')  # metrics=['accuracy']
    discriminator.trainable = False
    gan.compile(loss="binary_crossentropy", optimizer='rmsprop')  # metrics=['accuracy']
    #dataset
    #train_datas = read_and_decode(tfrecords_path)
    train_datas, iter = dataset_tfrecords(tfrecords_path, use_keras_fit=False)
    #sess = tf.Session()
    #sess.run(iter.initializer)
    generator, discriminator = gan.layers
    print("-----------------start---------------")
    sess = tf.Session()
    for step in range(num_steps):
        #get the time
        start_time = time.time()
        #phase 1 - training the discriminator
        noise = np.random.normal(size=batch_size*coding_size).reshape([batch_size, coding_size])
        noise = np.cast[np.float32](noise)
        generated_images = generator.predict(noise)
        train_datas_ = sess.run(train_datas)
        x_fake_and_real = np.concatenate([generated_images, train_datas_[0]], axis=0)
        #never call tf.concat (or any other tf op) inside the loop body:
        #it keeps growing the graph, exhausts memory and makes training slower and slower
        y1 = np.array([[0.]]*batch_size + [[1.]]*batch_size)
        discriminator.trainable = True
        dis_loss = discriminator.train_on_batch(x_fake_and_real, y1)
        #keras' train_on_batch is a good fit for GAN training
        #phase 2 - training the generator
        noise = np.random.normal(size=batch_size*coding_size).reshape([batch_size, coding_size])
        noise = np.cast[np.float32](noise)
        y2 = np.array([[1.]]*batch_size)
        discriminator.trainable = False
        ad_loss = gan.train_on_batch(noise, y2)
        duration = time.time() - start_time
        if step % 5 == 0:
            #gan.save_weights('gan.h5')
            print("The step is %d,discriminator loss:%.3f,adversarial loss:%.3f" % (step, dis_loss, ad_loss), end=' ')
            print('%.2f s/step' % (duration))
        if step % 30 == 0 and step != 0:
            noise = np.random.normal(size=[1, coding_size])
            noise = np.cast[np.float32](noise)
            fake_image = generator.predict(noise, steps=1)
            #recover the image:
            #1. multiply by 255 and cast to uint8
            #2. alternatively keep the float32 values in [0, 1], which can also be displayed directly
            arr_img = np.array([fake_image], np.float32).reshape([224, 224, 3])*255
            arr_img = np.cast[np.uint8](arr_img)
            #the tfrecords were written with PIL.Image (RGB), so convert to BGR before showing with cv2
            arr_img = cv2.cvtColor(arr_img, cv2.COLOR_RGB2BGR)
            cv2.imshow('fake image', arr_img)
            cv2.waitKey(1500)  # show the fake image for 1.5 s
            cv2.destroyAllWindows()
    #save the model
    model_version = '0001'
    model_name = 'gans'
    model_path = os.path.join(model_name, model_version)
    tf.saved_model.save(gan, model_path)

def main():
    training_keras()

main()
That completes a simple GAN training run.
References
Paper: "Generative Adversarial Networks"
Reference source code:
https://github.com/eriklindernoren/Keras-GAN/blob/master/gan/gan.py
Reference blog:
https://blog.csdn.net/u010138055/article/details/94441812
Closing remarks
I am just a master's student taking my first steps in deep learning and machine learning; feedback and corrections are very welcome.