GAN动漫人物头像生成
1.简介
搭建了一个简单的DCGAN网络生成动漫人物的头像,其中动漫人物头像数据集取自kaggle,网址如下
link
2.网络结构
- 数据集
- 生成器
- 判别器
2.1数据集
数据大小为64x64x3,样例如下
2.2生成器
由于生成器的原始输入是n维噪声,若想生成与数据集大小相同的图片,则需要进行上采样,这里我们用到的方法是转置卷积,通过pytorch中的ConvTransposed2d来实现。
生成器代码如下:
class Generator(nn.Module):
def __init__(self, noise_dim=100):
super(Generator, self).__init__()
self.net = nn.Sequential(
# out_shape = (1-1)*1-2*0+4 = 4*4
nn.ConvTranspose2d(noise_dim, 256, kernel_size=4),
nn.BatchNorm2d(256),
nn.ReLU(),
# out_shape = (4-1)*2-2*1+4 = 8*8
nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(),
# out_shape = (8-1)*2-2*1+4 = 16*16
nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
# out_shape = (16-1)*2-2*1+4 = 32*32
nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(),
# out_shape = (32-1)*2-2*1+4 = 64*64
nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
nn.Tanh()
)
def forward(self, input):
output = self.net(input)
return output
在训练阶段,我们会生成batch_sizex100x1x1大小的随机噪声,然后经过生成器的上采样,实现与数据集图片大小相同的伪图片,然后送到判别器中进行真假图片的辨别。
2.3判别器
判别器的输入为从数据集中采样的真实图片与生成器生成的伪图片,输出为0-1之间的数值,因此网络尾端使用了Sigmoid激活函数。
判别器的目的是对真实图片判别为“1”(真),对伪图片判别为“0”(假),而生成器的目的是生成的伪图片足够好,足够逼近数据集的分布,以此骗过生成器,因此生成器希望自己生成的伪图片在判别器中得分越接近“1”(真)越好。这样判别器与生成器不断“对抗”,最后达到平衡或接*衡。
判别器代码如下:
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.net = nn.Sequential(
# 32*32*32
nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
nn.LeakyReLU(0.2),
# 16*16*64
nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
nn.LeakyReLU(0.2),
# 8*8*128
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
nn.LeakyReLU(0.2),
# 4*4*256
nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
nn.LeakyReLU(0.2),
nn.Flatten(),
nn.Linear(4*4*256, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, input):
output = self.net(input)
return output.view(-1)
判别器的网络就是简单的前馈神经网络,经过卷积不断的下采样,提取图片的特征,最后输出为真或为假的0-1之间的得分。
3.训练阶段
训练阶段的大体流程跟深度学习训练流程相差无几,最重要的部分是label与损失函数的设计与计算。
先贴上训练阶段的代码:
import torch
import torch.nn as nn
from torchvision import transforms
from create_dataset import My_dataset, save_img
from torch.utils.data import DataLoader
from net import Generator, Discriminator
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
dataset = My_dataset('./data', transform=transform)
batch_size, epochs = 256, 200
my_dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=True)
discriminator = Discriminator()
generator = Generator()
if torch.cuda.is_available():
discriminator = discriminator.cuda()
generator = generator.cuda()
d_optimizer = torch.optim.Adam(discriminator.parameters(), betas=(0.5, 0.99), lr=1e-4)
g_optimizer = torch.optim.Adam(generator.parameters(), betas=(0.5, 0.99), lr=1e-4)
criterion = nn.BCELoss()
for epoch in range(epochs):
for i, img in enumerate(my_dataloader):
noise = torch.randn(batch_size, 100, 1, 1).cuda()
real_img = img.cuda()
fake_img = generator(noise)
real_label = torch.ones(batch_size).cuda()
fake_label = torch.zeros(batch_size).cuda()
real_out = discriminator(real_img)
fake_out = discriminator(fake_img)
real_loss = criterion(real_out, real_label)
fake_loss = criterion(fake_out, fake_label)
d_loss = real_loss + fake_loss
d_optimizer.zero_grad()
d_loss.backward()
d_optimizer.step()
noise = torch.randn(batch_size, 100, 1, 1).cuda()
fake_img = generator(noise)
output = discriminator(fake_img)
g_loss = criterion(output, real_label)
g_optimizer.zero_grad()
g_loss.backward()
g_optimizer.step()
if (i + 1) % 5 == 0:
print('Epoch[{}/{}],d_loss:{:.6f},g_loss:{:.6f} '
'D_real: {:.6f},D_fake: {:.6f}'.format(
epoch, epochs, d_loss.data.item(), g_loss.data.item(),
real_out.data.mean(), fake_out.data.mean() # 打印的是真实图片的损失均值
))
if epoch == 0 and i == len(my_dataloader) - 1:
save_img(img[:64, :, :, :], './sample/real_images.png')
if (epoch+1) % 10 == 0 and i == len(my_dataloader)-1:
save_img(fake_img[:64, :, :, :], './sample/fake_images_{}.png'.format(epoch + 1))
torch.save(generator.state_dict(), './generator.pth')
torch.save(discriminator.state_dict(), './discriminator.pth')
在训练之前,首先要人为的设置图片为真、假的label,这里我们真设置为1,使用torch.ones函数实现,假设置为0,使用torch.zeros函数实现。
然后就是对数据集中的图片以及生成器生成的伪图片进行判别损失的计算,如代码中d_loss。
接下来是对生成器损失的计算,因为生成器的目的是生成的图片越真越好,所以生成器损失的计算的label是1。如代码中g_loss。
4.反归一化及结果
4.1反归一化
因为对数据集进行了归一化及标准化处理,所以在显示生成器结果时需要进行反归一化,在这里我先是使用了torchvision中的save_image去保存生成器的结果,但是该官方函数的反归一化与我们的归一化过程不符,导致该函数保存的图片有些暗,如下所示(下图为数据集中的真实图片):
所以这里另进行了数据的反归一化过程并使用torchvision中的make_grid函数进行保存,结果如下(为数据集中的真实图片):
4.2结果
训练200轮,每10轮保存一次结果,其中10,50,100,150,200轮的结果如下图所示:
可以看到,生成器生成的图片在逐步清晰,且越来越逼近数据集的分布
5.总结
最后效果还是不太好,GAN的训练过程也不是太稳定,尤其是如何让图片更加清晰,不模糊仍然是一个比较“棘手”的问题。
(新手小白第一次写博客,大神勿喷)
最后,全部代码可见我的github