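The loss functions of the generator and the discriminator are described separately below:
1. Generator loss: the generator's loss is the sum of an adversarial loss and a pixel-wise (L1) loss.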
def backward_G(self):
    """Calculate GAN and L1 loss for the generator"""
    # 1. Adversarial loss: G(A) should fool the discriminator
    fake_AB = torch.cat((self.real_A, self.fake_B), 1)
    pred_fake = self.netD(fake_AB)
    self.loss_G_GAN = self.criterionGAN(pred_fake, True)
    # 2. Pixel-wise loss: G(A) should match B, weighted by lambda_L1
    self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
    # Combine the two losses and calculate gradients
    self.loss_G = self.loss_G_GAN + self.loss_G_L1
    self.loss_G.backward()
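To make the arithmetic concrete, here is a minimal sketch in plain Python (no PyTorch) of how the two generator terms combine. It assumes the adversarial criterion is binary cross-entropy on probabilities, which is only one possible configuration of `criterionGAN`, and it assumes a weight of 100 for `lambda_L1`. The helpers `bce` and `l1` and all the toy numbers are hypothetical and exist only for illustration.

```python
import math

def bce(probs, target):
    """Mean binary cross-entropy between probabilities in (0, 1) and a constant target (1.0 or 0.0)."""
    total = 0.0
    for p in probs:
        total += -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))
    return total / len(probs)

def l1(pred, truth):
    """Mean absolute error between two equal-length lists."""
    return sum(abs(a - b) for a, b in zip(pred, truth)) / len(pred)

# Hypothetical toy values (assumptions, not taken from any real run).
pred_fake = [0.5, 0.25]        # discriminator probabilities for the pair (A, G(A))
fake_B = [0.0, 0.5, 1.0]       # a few generated pixel values
real_B = [0.0, 0.0, 1.0]       # the matching ground-truth pixel values
lambda_l1 = 100.0              # assumed weight of the L1 term

# The generator is rewarded when the discriminator labels its output as real (target 1.0).
loss_g_gan = bce(pred_fake, 1.0)
# The pixel-wise term pulls G(A) toward B and is scaled by lambda_l1.
loss_g_l1 = l1(fake_B, real_B) * lambda_l1
loss_g = loss_g_gan + loss_g_l1
```

With a weight this large, the L1 term dominates the total unless the generated pixels are already close to the target, which is why the weight is exposed as an option rather than fixed.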
2. Discriminator loss: in pix2pix the discriminator's loss is the same as in cGAN.
def backward_D(self):
    """Calculate GAN loss for the discriminator"""
    # Fake (second term of the cGAN objective); stop backprop to the generator by detaching the fake pair
    # We use conditional GANs, so both input and output are fed to the discriminator
    fake_AB = torch.cat((self.real_A, self.fake_B), 1)
    pred_fake = self.netD(fake_AB.detach())
    self.loss_D_fake = self.criterionGAN(pred_fake, False)
    # Real (first term of the cGAN objective)
    real_AB = torch.cat((self.real_A, self.real_B), 1)
    pred_real = self.netD(real_AB)
    self.loss_D_real = self.criterionGAN(pred_real, True)
    # Combine the two losses and calculate gradients
    self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
    self.loss_D.backward()
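The discriminator's two terms can be traced the same way. The following is a minimal plain-Python sketch, again assuming binary cross-entropy on probabilities as the adversarial criterion. The helper `bce` and the toy probabilities are hypothetical. Gradient handling such as `detach()` is not modeled here, only the loss values.

```python
import math

def bce(probs, target):
    """Mean binary cross-entropy between probabilities in (0, 1) and a constant target (1.0 or 0.0)."""
    total = 0.0
    for p in probs:
        total += -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))
    return total / len(probs)

# Hypothetical toy discriminator outputs (assumptions for illustration).
pred_fake = [0.25, 0.5]   # probabilities for the fake pair (A, G(A)); ideally near 0
pred_real = [0.5, 0.5]    # probabilities for the real pair (A, B); ideally near 1

# Fake pairs are scored against target 0.0, real pairs against target 1.0.
loss_d_fake = bce(pred_fake, 0.0)
loss_d_real = bce(pred_real, 1.0)
# Averaging the two terms mirrors the 0.5 factor in backward_D.
loss_d = (loss_d_fake + loss_d_real) * 0.5
```

Compared with the generator sketch, the only change on the fake pair is the target: the generator scores `pred_fake` against 1.0, while the discriminator scores it against 0.0.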