pytorch(二十六):自动编码器

一、自动编码器

1、auto_encoder.py（文件名需与 main.py 中的 `from auto_encoder import AE` 保持一致）

import torch
from torch import nn

class AE(nn.Module):
    """Fully connected autoencoder for 28x28 single-channel images.

    The encoder compresses each flattened image (784 values) into a
    20-dimensional code. The decoder maps that code back to 784 values in
    [0, 1] (via a final Sigmoid) and reshapes them to [b, 1, 28, 28].
    """

    def __init__(self):
        super().__init__()

        # [b, 784] => [b, 20]
        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 20),
            # NOTE: this final ReLU constrains the latent code to be >= 0.
            nn.ReLU(),
        )

        # [b, 20] => [b, 784]
        self.decoder = nn.Sequential(
            nn.Linear(20, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            # Squash outputs into [0, 1] so they are comparable to pixel
            # inputs scaled to that range (e.g. by transforms.ToTensor()).
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode and reconstruct a batch of images.

        :param x: tensor of shape [b, 1, 28, 28]; any layout with exactly
            784 elements per sample is accepted, contiguous or not.
        :return: tuple ``(x_hat, None)``. ``x_hat`` is the reconstruction
            with shape [b, 1, 28, 28] and values in [0, 1]. The second
            element is always ``None``; callers unpack it as ``x_hat, _``.
        """
        batchsz = x.shape[0]
        # Flatten. reshape (unlike view) also accepts non-contiguous input,
        # e.g. a transposed or sliced tensor, copying only when necessary.
        x = x.reshape(batchsz, 784)
        # encoder: [b, 784] => [b, 20]
        x = self.encoder(x)
        # decoder: [b, 20] => [b, 784]
        x = self.decoder(x)
        # Restore the image layout. The decoder output is freshly
        # allocated and contiguous, so view is safe here.
        x = x.view(batchsz, 1, 28, 28)

        return x, None

2、main.py

import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from auto_encoder import AE
from torch import nn, optim
import visdom
def main(epochs=1000, batch_size=32, lr=1e-3):
    """Train the AE on MNIST and show reconstructions in visdom.

    Each epoch makes one pass over the MNIST training set, minimising the
    pixel-wise MSE between each input batch and its reconstruction. It then
    reconstructs one batch from the (shuffled) test loader and pushes the
    inputs and reconstructions to the visdom windows "x" and "x_hat".

    :param epochs: number of passes over the training set (default 1000).
    :param batch_size: mini-batch size for both loaders (default 32).
    :param lr: Adam learning rate (default 1e-3).
    """
    to_tensor = transforms.Compose([
        transforms.ToTensor()
    ])
    mnist_train = datasets.MNIST("mnist", True, transform=to_tensor, download=True)
    mnist_train = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)

    mnist_test = datasets.MNIST("mnist", False, transform=to_tensor, download=True)
    # Shuffled so that each epoch visualises a different test batch.
    mnist_test = DataLoader(mnist_test, batch_size=batch_size, shuffle=True)

    # Sanity check of the input batch shape.
    x, _ = next(iter(mnist_train))
    print(x.shape)

    model = AE()
    criton = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    viz = visdom.Visdom()
    for epoch in range(epochs):
        model.train()
        total_loss, num_batches = 0.0, 0
        for x, _ in mnist_train:
            # [b, 1, 28, 28]
            x_hat, _ = model(x)
            loss = criton(x_hat, x)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        # Report the mean loss over the epoch. The last batch's loss alone
        # is noisy, and it would be undefined if the loader were empty.
        print(epoch, "loss:", total_loss / max(num_batches, 1))

        model.eval()
        x, _ = next(iter(mnist_test))
        with torch.no_grad():
            x_hat, _ = model(x)
        viz.images(x, nrow=8, win="x", opts=dict(title="x"))
        viz.images(x_hat, nrow=8, win="x_hat", opts=dict(title="x_hat"))

if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    main()

 

上一篇:线性表(linear_list)


下一篇:使用torch加载模型时出现字典键对应不起来的问题