生成对抗网络代码详解(一):GAN

首先导入必要的模块

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image

import os

然后设置一些超参数

z_dim = 100 #噪声维度
batch_size = 64
learning_rate = 0.0002
total_epochs = 100

如果你的计算机有GPU,可以指定使用哪块GPU

gpu_id ='0'
if gpu_id is not None:
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

创建路径,存放每次迭代生成的图片

if os.path.exists('gan_images') is False:
    os.makedirs('gan_images')

以上都是一些准备工作,接下来开始定义模型,首先是判别器Discriminator:

class Discriminator(nn.Module):
    '''定义全连接判别器'''
    def __init__(self):
         super(Discriminator,self).__init__()

         layers =[]
         #first floor
         layers.append(nn.Linear(in_features=28*28,out_features=512,bias=True))
         layers.append(nn.LeakyReLU(0.2,inplace=True))
         #second floor
         layers.append(nn.Linear(in_features=512, out_features=256, bias=True))
         layers.append(nn.LeakyReLU(0.2, inplace=True))
         # outpur floor
         layers.append(nn.Linear(in_features=256, out_features=1, bias=True))
         layers.append(nn.Sigmoid())

         self.model = nn.Sequential(*layers)

    def forward(self, x):
         x = x.view(x.size(0),-1)
         validity = self.model(x)
         return validity

首先创建一个空列表layers,存放判别器的全连接层。由于我使用的是MNIST数据集,因此第一层的输入为28*28,输出特征设置为512,并添加激活层LeakyReLU

上一篇:Social GAN代码要点记录


下一篇:windows下引导双系统卸载Linux