VAE 的网络结构与 AE 相比只有少量改动:编码器的输出被拆分为 mu 和 sigma,再通过重参数化技巧采样得到隐变量。
import torch
import numpy as np
from torch import nn
class VAE(nn.Module):
    """Fully connected variational autoencoder for 1x28x28 images.

    The encoder maps a flattened image [b, 784] to 20 values that are split
    into ``mu`` [b, 10] and ``sigma`` [b, 10]. A latent code is sampled with
    the reparameterization trick and decoded back to [b, 784].
    """

    def __init__(self):
        super(VAE, self).__init__()
        # [b, 784] => [b, 20]
        #   mu:    [b, 10]
        #   sigma: [b, 10]
        # There is deliberately no activation after the last Linear layer:
        # a trailing ReLU would clamp mu to be non-negative and could force
        # sigma to exactly 0, which blows up the log term of the KL divergence.
        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 20),
        )
        # [b, 10] => [b, 784]
        self.decoder = nn.Sequential(
            nn.Linear(10, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid(),  # squash the output into the range 0-1
        )

    def forward(self, x):
        """Reconstruct ``x`` and compute the KL term of the VAE loss.

        :param x: input batch of shape [b, 1, 28, 28]
        :return: tuple ``(x_hat, kld)`` where ``x_hat`` has shape
            [b, 1, 28, 28] and ``kld`` is a scalar tensor
        """
        batchsz = x.size(0)
        # Flatten each image to a 784-vector.
        x = x.view(batchsz, 784)
        # Encoder output [b, 20] holds both the mean and sigma.
        h_ = self.encoder(x)
        # This is where the VAE differs from the plain AE:
        # split [b, 20] into mu [b, 10] and sigma [b, 10].
        mu, sigma = h_.chunk(2, dim=1)
        # Reparameterization trick: h = mu + sigma * eps, with eps ~ N(0, 1).
        # The sign of sigma does not matter because eps is symmetric and the
        # KL term below only uses sigma ** 2.
        h = mu + sigma * torch.randn_like(sigma)
        # Decode and restore the image layout that was flattened above.
        x_hat = self.decoder(h)
        x_hat = x_hat.view(batchsz, 1, 28, 28)
        # KL divergence between N(mu, sigma^2) and the prior N(0, 1).
        # 1e-8 guards against log(0). The sum is divided by the number of
        # pixels in the batch (batchsz * 28 * 28).
        kld = 0.5 * torch.sum(
            torch.pow(mu, 2)
            + torch.pow(sigma, 2)
            - torch.log(1e-8 + torch.pow(sigma, 2))
            - 1
        ) / (batchsz * 28 * 28)
        return x_hat, kld
与 AE 相比,forward 只是多返回了一个 KL 散度 kld。
主函数部分的变化如下:
import torch
from torchvision import transforms,datasets # datasets自带数据集MNIST
from torch.utils.data import DataLoader
from AE import AE
from VAE import VAE
from torch import nn,optim
import visdom
def main():
    """Train the VAE on MNIST and show reconstructions in visdom.

    Every epoch runs one pass over the training loader, prints the last
    batch's loss and KL term, then reconstructs one batch from the test
    split and sends the originals and reconstructions to visdom.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])

    # Training split, wrapped in a shuffling DataLoader.
    mnist_train = datasets.MNIST('mnist', True, transform=to_tensor, download=True)
    mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)

    # Test split (train=False) so that the reconstructions shown below come
    # from images the model was not trained on.
    mnist_test = datasets.MNIST('mnist', False, transform=to_tensor, download=True)
    mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)

    # Peek at one batch; labels are ignored because training is unsupervised.
    x, _ = next(iter(mnist_train))
    print('x:', x.shape)  # x: torch.Size([32, 1, 28, 28])

    # Use the GPU when one is available, otherwise run on the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = VAE().to(device)
    criteon = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    viz = visdom.Visdom()

    for epoch in range(1000):
        for x, _ in mnist_train:
            # [b, 1, 28, 28]
            x = x.to(device)
            # x_hat is the reconstruction of x.
            x_hat, kld = model(x)
            loss = criteon(x_hat, x)
            # VAE only: maximizing elbo = -reconstruction_loss - kld is the
            # same as minimizing reconstruction_loss + kld.
            if kld is not None:
                loss = loss + 1.0 * kld

            # Backprop: clear gradients, backward pass, parameter update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(epoch, 'loss', loss.item(), 'kld', kld.item())

        # Reconstruct one batch of test images without tracking gradients.
        x, _ = next(iter(mnist_test))
        x = x.to(device)
        with torch.no_grad():
            x_hat, kld = model(x)

        # Visualize the originals and their reconstructions.
        viz.images(x, nrow=8, win='x', opts=dict(title='x'))
        viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title='x_hat'))
# Run the training loop only when this file is executed as a script.
if __name__ == "__main__":
    main()
得到结果
具体原因可能是任务太简单了,体现不出 VAE 的优势。