首先导入必要的模块
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
import os
然后设置一些超参数
z_dim = 100 #噪声维度
batch_size = 64
learning_rate = 0.0002
total_epochs = 100
如果你的计算机有GPU,可以指定使用哪块GPU
gpu_id ='0'
if gpu_id is not None:
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id
device = torch.device('cuda')
else:
device = torch.device('cpu')
创建路径,存放每次迭代生成的图片
if os.path.exists('gan_images') is False:
os.makedirs('gan_images')
以上都是一些准备工作,接下来开始定义模型,首先是判别器Discriminator:
class Discriminator(nn.Module):
'''定义全连接判别器'''
def __init__(self):
super(Discriminator,self).__init__()
layers =[]
#first floor
layers.append(nn.Linear(in_features=28*28,out_features=512,bias=True))
layers.append(nn.LeakyReLU(0.2,inplace=True))
#second floor
layers.append(nn.Linear(in_features=512, out_features=256, bias=True))
layers.append(nn.LeakyReLU(0.2, inplace=True))
# outpur floor
layers.append(nn.Linear(in_features=256, out_features=1, bias=True))
layers.append(nn.Sigmoid())
self.model = nn.Sequential(*layers)
def forward(self, x):
x = x.view(x.size(0),-1)
validity = self.model(x)
return validity
首先创建一个空列表layers,存放判别器的全连接层。由于我使用的是MNIST数据集,因此第一层的输入为28*28,输出特征设置为512,并添加激活层LeakyReLU