一、自动编码器
1、AE.py
import torch
from torch import nn


class AE(nn.Module):
    """Fully connected autoencoder for 28x28 single-channel images.

    The encoder compresses each flattened 784-value image to a 20-dim code;
    the decoder maps that code back to 784 values squashed by a sigmoid and
    reshapes them to [b, 1, 28, 28].
    """

    def __init__(self):
        super().__init__()
        # [b, 784] => [b, 20]
        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 20),
            # NOTE(review): a ReLU on the bottleneck clamps the code to >= 0
            # and can leave latent units stuck at zero; kept as-is to preserve
            # the original architecture.
            nn.ReLU(),
        )
        # [b, 20] => [b, 784]
        self.decoder = nn.Sequential(
            nn.Linear(20, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            # Squash reconstructions into the [0, 1] range.
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode and decode a batch of images.

        :param x: tensor whose first dim is the batch and which holds 784
            values per sample, e.g. [b, 1, 28, 28] or [b, 784]
        :return: tuple (x_hat, None); x_hat has shape [b, 1, 28, 28]. The
            second element is always None (presumably a placeholder for an
            extra term such as a VAE's KL loss — TODO confirm).
        """
        batchsz = x.shape[0]
        # flatten; reshape (unlike view) also accepts non-contiguous inputs,
        # e.g. transposed or permuted image batches
        x = x.reshape(batchsz, 784)
        # encoder
        x = self.encoder(x)
        # decoder
        x = self.decoder(x)
        # the decoder output is freshly allocated and contiguous, so view is safe
        x = x.view(batchsz, 1, 28, 28)
        return x, None
2、main.py
import torch
import visdom
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# NOTE(review): expects the AE class in a module named auto_encoder.py;
# adjust this import if the class is saved as AE.py instead.
from auto_encoder import AE


def _mnist_loader(train, batch_size, shuffle):
    """Return a DataLoader over one MNIST split, downloading it into ./mnist if needed.

    :param train: True for the training split, False for the test split
    :param batch_size: number of samples per batch
    :param shuffle: whether the loader reshuffles the data on every pass
    """
    dataset = datasets.MNIST(
        "mnist",
        train,
        transform=transforms.Compose([transforms.ToTensor()]),
        download=True,
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


def main(epochs=1000, batch_size=32, lr=1e-3):
    """Train the AE on MNIST and show test reconstructions in visdom.

    After every epoch the loss of the last training batch is printed, and one
    test batch plus its reconstruction are drawn in the "x" and "x_hat"
    visdom windows.

    :param epochs: number of passes over the training set
    :param batch_size: batch size used by both the train and test loaders
    :param lr: Adam learning rate
    """
    mnist_train = _mnist_loader(True, batch_size, shuffle=True)
    # Shuffled so that each epoch visualizes a different random test batch.
    mnist_test = _mnist_loader(False, batch_size, shuffle=True)

    # Sanity check of the input batch shape.
    x, _ = next(iter(mnist_train))
    print(x.shape)

    model = AE()
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Images are sent to a visdom server, which must be running separately
    # (e.g. `python -m visdom.server`).
    viz = visdom.Visdom()

    for epoch in range(epochs):
        model.train()
        for x, _ in mnist_train:
            # x: [b, 1, 28, 28]; labels are ignored (reconstruction target is x itself)
            x_hat, _ = model(x)
            loss = criterion(x_hat, x)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Loss of the final batch of this epoch, not an epoch average.
        print(epoch, "loss:", loss.item())

        model.eval()
        x, _ = next(iter(mnist_test))
        with torch.no_grad():
            x_hat, _ = model(x)
        viz.images(x, nrow=8, win="x", opts=dict(title="x"))
        viz.images(x_hat, nrow=8, win="x_hat", opts=dict(title="x_hat"))


if __name__ == "__main__":
    main()