pix2pix损失函数理解(精)

pix2pix损失函数理解(精)

pix2pix损失函数理解(精) 

pix2pix损失函数理解(精) 

pix2pix损失函数理解(精) 

 下面分为生成器和鉴别器的损失函数分别进行说明:

1.生成器(generator)的损失函数:生成器的损失函数由对抗损失和像素损失构成。

pix2pix损失函数理解(精)

    def backward_G(self):
        """Calculate GAN and L1 loss for the generator"""
        # 1.对抗损失,G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)
        # 2.像素损失,G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
        # combine loss and calculate gradients
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

 (2)判别器的损失函数: pix2pix中判别器的损失与cGAN相同。

pix2pix损失函数理解(精)

    def backward_D(self):
        """Calculate GAN loss for the discriminator"""
        # Fake; 后半部分,stop backprop to the generator by detaching fake_B
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)  # we use conditional GANs; we need to feed both input and output to the discriminator
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)
        # Real:前半部分
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)
        # combine loss and calculate gradients
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

 Pix2pix-两个领域匹配图像的转换 - 简书

上一篇:solve the problem of fake coins


下一篇:【小白学习PyTorch教程】十一、基于MNIST数据集训练第一个生成性对抗网络