PyTorch Chapter 6


I. Neural Style Transfer

1. Image Style Transfer


2. Image Representation


3. Content Loss

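As a quick reference (matching the implementation in Section IV), the content loss compares the VGG feature maps of the generated image x and the content image c. A common form, up to a normalization constant, is

\mathcal{L}_{\text{content}} = \sum_{l} \frac{1}{C_l H_l W_l} \left\lVert F^{l}(x) - F^{l}(c) \right\rVert_2^2

where F^l is the feature map of layer l. The original paper uses a single layer and a factor of 1/2; the code below takes the mean squared difference over five selected layers.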

4. Style Loss

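As a quick reference, the style loss is built from Gram matrices of the feature maps. Reshaping the feature map of layer l to F^l with shape (C_l, H_l W_l), a common form is

G^{l} = F^{l} (F^{l})^{\top}, \qquad \mathcal{L}_{\text{style}} = \sum_{l} w_l \left\lVert G^{l}(x) - G^{l}(s) \right\rVert_F^2

where s is the style image. Normalization constants differ between implementations; the code in Section IV averages over the Gram-matrix entries, divides by C_l H_l W_l, and weights the total style loss by style_weight = 100.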

II. Generative Adversarial Networks

1. Basic Concept

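As a quick reference, a GAN trains a generator G and a discriminator D against each other with the minimax objective

\min_G \max_D \; V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]

D tries to tell real samples from generated ones, while G tries to produce samples that D classifies as real; the code in Section IV implements this with BCELoss.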

2. DCGAN


III. CycleGAN

1. Network


2. Model Architecture


3. Loss Functions

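As a quick reference, CycleGAN learns two mappings G: X -> Y and F: Y -> X with two discriminators D_Y and D_X, and adds a cycle-consistency loss on top of the adversarial losses:

\mathcal{L}_{\text{cyc}}(G, F) = \mathbb{E}_{x}\left[\lVert F(G(x)) - x \rVert_1\right] + \mathbb{E}_{y}\left[\lVert G(F(y)) - y \rVert_1\right]

\mathcal{L}(G, F, D_X, D_Y) = \mathcal{L}_{\text{GAN}}(G, D_Y, X, Y) + \mathcal{L}_{\text{GAN}}(F, D_X, Y, X) + \lambda \, \mathcal{L}_{\text{cyc}}(G, F)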

IV. Code in Practice

1. Image Style Transfer (Neural Style Transfer)

The code is as follows:

%matplotlib inline

from __future__ import division
from torchvision import models
from torchvision import transforms
from PIL import Image
import argparse
import torch
import torchvision
import torch.nn as nn
import numpy as np

import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_image(image_path, transform=None, max_size=None, shape=None):
    image = Image.open(image_path)
    if max_size:
        scale = max_size / max(image.size)
        size = (np.array(image.size) * scale).astype(int)
        image = image.resize(tuple(size), Image.LANCZOS)  # Image.ANTIALIAS was removed in newer Pillow; LANCZOS is the same filter
         
    if shape:
        image = image.resize(shape, Image.LANCZOS)
        
    if transform:
        image = transform(image).unsqueeze(0)
        
    return image.to(device)


transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
]) # mean and std from ImageNet

content = load_image("png/content.png", transform, max_size=400)
style = load_image("png/style.png", transform, shape=[content.size(2), content.size(3)])

# content = load_image("png/content.png", transforms.Compose([
#     transforms.ToTensor(),
# ]), max_size=400)
# style = load_image("png/style.png", transforms.Compose([
#     transforms.ToTensor(),
# ]), shape=[content.size(2), content.size(3)])
style.shape

Output:

torch.Size([1, 3, 400, 272])
unloader = transforms.ToPILImage()  # reconvert into PIL image

plt.ion()

def imshow(tensor, title=None):
    image = tensor.cpu().clone()  # we clone the tensor to not do changes on it
    image = image.squeeze(0)      # remove the fake batch dimension
    image = unloader(image)
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.pause(0.001) # pause a bit so that plots are updated


plt.figure()
imshow(style[0], title='Image')
# content.shape

Output: (image)

class VGGNet(nn.Module):
    def __init__(self):
        super(VGGNet, self).__init__()
        self.select = ['0', '5', '10', '19', '28']  # conv1_1, conv2_1, conv3_1, conv4_1, conv5_1 in vgg19.features
        self.vgg = models.vgg19(pretrained=True).features
        
    def forward(self, x):
        features = []
        for name, layer in self.vgg._modules.items():
            x = layer(x)
            if name in self.select:
                features.append(x)
        return features


target = content.clone().requires_grad_(True)
optimizer = torch.optim.Adam([target], lr=0.003, betas=[0.5, 0.999])
vgg = VGGNet().to(device).eval()
target_features = vgg(target)
total_step = 2000
style_weight = 100.
for step in range(total_step):
    target_features = vgg(target)
    content_features = vgg(content)
    style_features = vgg(style)
    
    style_loss = 0
    content_loss = 0
    for f1, f2, f3 in zip(target_features, content_features, style_features):
        content_loss += torch.mean((f1-f2)**2)
        _, c, h, w = f1.size()
        f1 = f1.view(c, h*w)
        f3 = f3.view(c, h*w)
        
        # compute the Gram matrix
        f1 = torch.mm(f1, f1.t())
        f3 = torch.mm(f3, f3.t())
        style_loss += torch.mean((f1-f3)**2)/(c*h*w)
        
    loss = content_loss + style_weight * style_loss
    
    # update the target image
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if step % 10 == 0:
        print("Step [{}/{}], Content Loss: {:.4f}, Style Loss: {:.4f}"
             .format(step, total_step, content_loss.item(), style_loss.item()))  

Partial output:

Step [0/2000], Content Loss: 0.0000, Style Loss: 531.1730
Step [10/2000], Content Loss: 6.0654, Style Loss: 360.6187
Step [20/2000], Content Loss: 11.3430, Style Loss: 253.8006
Step [30/2000], Content Loss: 14.5195, Style Loss: 190.0798
Step [40/2000], Content Loss: 16.5578, Style Loss: 152.3939
Step [50/2000], Content Loss: 17.9683, Style Loss: 129.4922
Step [60/2000], Content Loss: 19.0225, Style Loss: 114.5218
Step [70/2000], Content Loss: 19.8584, Style Loss: 103.7824
Step [80/2000], Content Loss: 20.5509, Style Loss: 95.5047
Step [90/2000], Content Loss: 21.1601, Style Loss: 88.7919
Step [100/2000], Content Loss: 21.6844, Style Loss: 83.1393
Step [110/2000], Content Loss: 22.1447, Style Loss: 78.2809
Step [120/2000], Content Loss: 22.5605, Style Loss: 74.0401
Step [130/2000], Content Loss: 22.9415, Style Loss: 70.2842
Step [140/2000], Content Loss: 23.2941, Style Loss: 66.9353
Step [150/2000], Content Loss: 23.6130, Style Loss: 63.9158
Step [160/2000], Content Loss: 23.9114, Style Loss: 61.1637
Step [170/2000], Content Loss: 24.1892, Style Loss: 58.6509
Step [180/2000], Content Loss: 24.4448, Style Loss: 56.3407
Step [190/2000], Content Loss: 24.6883, Style Loss: 54.1998
Step [200/2000], Content Loss: 24.9212, Style Loss: 52.2185
Step [210/2000], Content Loss: 25.1355, Style Loss: 50.3827
Step [220/2000], Content Loss: 25.3350, Style Loss: 48.6758
Step [230/2000], Content Loss: 25.5269, Style Loss: 47.0833
Step [240/2000], Content Loss: 25.7123, Style Loss: 45.5909
Step [250/2000], Content Loss: 25.8884, Style Loss: 44.1901
Step [260/2000], Content Loss: 26.0555, Style Loss: 42.8741
Step [270/2000], Content Loss: 26.2152, Style Loss: 41.6320
Step [280/2000], Content Loss: 26.3691, Style Loss: 40.4600
Step [290/2000], Content Loss: 26.5208, Style Loss: 39.3519
Step [300/2000], Content Loss: 26.6641, Style Loss: 38.3040
Step [310/2000], Content Loss: 26.8034, Style Loss: 37.3103
Step [320/2000], Content Loss: 26.9339, Style Loss: 36.3693
Step [330/2000], Content Loss: 27.0649, Style Loss: 35.4760
Step [340/2000], Content Loss: 27.1923, Style Loss: 34.6284
Step [350/2000], Content Loss: 27.3130, Style Loss: 33.8245
Step [360/2000], Content Loss: 27.4284, Style Loss: 33.0575
Step [370/2000], Content Loss: 27.5356, Style Loss: 32.3269
Step [380/2000], Content Loss: 27.6426, Style Loss: 31.6281
Step [390/2000], Content Loss: 27.7454, Style Loss: 30.9596
Step [400/2000], Content Loss: 27.8430, Style Loss: 30.3200
Step [410/2000], Content Loss: 27.9398, Style Loss: 29.7072
Step [420/2000], Content Loss: 28.0368, Style Loss: 29.1180
Step [430/2000], Content Loss: 28.1289, Style Loss: 28.5518
Step [440/2000], Content Loss: 28.2207, Style Loss: 28.0077
Step [450/2000], Content Loss: 28.3101, Style Loss: 27.4842
Step [460/2000], Content Loss: 28.4016, Style Loss: 26.9804
Step [470/2000], Content Loss: 28.4844, Style Loss: 26.4949
Step [480/2000], Content Loss: 28.5667, Style Loss: 26.0286
Step [490/2000], Content Loss: 28.6440, Style Loss: 25.5799
Step [500/2000], Content Loss: 28.7183, Style Loss: 25.1476
Step [510/2000], Content Loss: 28.7939, Style Loss: 24.7302
Step [520/2000], Content Loss: 28.8708, Style Loss: 24.3261
Step [530/2000], Content Loss: 28.9440, Style Loss: 23.9349
Step [540/2000], Content Loss: 29.0163, Style Loss: 23.5566
Step [550/2000], Content Loss: 29.0864, Style Loss: 23.1890
Step [560/2000], Content Loss: 29.1529, Style Loss: 22.8329
Step [570/2000], Content Loss: 29.2189, Style Loss: 22.4880
Step [580/2000], Content Loss: 29.2833, Style Loss: 22.1529
Step [590/2000], Content Loss: 29.3477, Style Loss: 21.8286
Step [600/2000], Content Loss: 29.4093, Style Loss: 21.5141
Step [610/2000], Content Loss: 29.4694, Style Loss: 21.2083
Step [620/2000], Content Loss: 29.5252, Style Loss: 20.9107
Step [630/2000], Content Loss: 29.5821, Style Loss: 20.6206
Step [640/2000], Content Loss: 29.6378, Style Loss: 20.3381
Step [650/2000], Content Loss: 29.6938, Style Loss: 20.0623
Step [660/2000], Content Loss: 29.7449, Style Loss: 19.7930
Step [670/2000], Content Loss: 29.7975, Style Loss: 19.5310
Step [680/2000], Content Loss: 29.8479, Style Loss: 19.2760
Step [690/2000], Content Loss: 29.8950, Style Loss: 19.0278
Step [700/2000], Content Loss: 29.9427, Style Loss: 18.7856
Step [710/2000], Content Loss: 29.9889, Style Loss: 18.5502
Step [720/2000], Content Loss: 30.0369, Style Loss: 18.3209
Step [730/2000], Content Loss: 30.0841, Style Loss: 18.0967
Step [740/2000], Content Loss: 30.1312, Style Loss: 17.8776
Step [750/2000], Content Loss: 30.1793, Style Loss: 17.6630
Step [760/2000], Content Loss: 30.2209, Style Loss: 17.4535
Step [770/2000], Content Loss: 30.2625, Style Loss: 17.2486
Step [780/2000], Content Loss: 30.3043, Style Loss: 17.0483
Step [790/2000], Content Loss: 30.3472, Style Loss: 16.8526
Step [800/2000], Content Loss: 30.3883, Style Loss: 16.6612
Step [810/2000], Content Loss: 30.4279, Style Loss: 16.4737
Step [820/2000], Content Loss: 30.4663, Style Loss: 16.2899
Step [830/2000], Content Loss: 30.5036, Style Loss: 16.1099
Step [840/2000], Content Loss: 30.5427, Style Loss: 15.9336
Step [850/2000], Content Loss: 30.5801, Style Loss: 15.7608
Step [860/2000], Content Loss: 30.6190, Style Loss: 15.5913
Step [870/2000], Content Loss: 30.6561, Style Loss: 15.4249
Step [880/2000], Content Loss: 30.6927, Style Loss: 15.2619
Step [890/2000], Content Loss: 30.7275, Style Loss: 15.1023
Step [900/2000], Content Loss: 30.7620, Style Loss: 14.9457
Step [910/2000], Content Loss: 30.7954, Style Loss: 14.7917
Step [920/2000], Content Loss: 30.8298, Style Loss: 14.6399
Step [930/2000], Content Loss: 30.8670, Style Loss: 14.4906
Step [940/2000], Content Loss: 30.9016, Style Loss: 14.3440
Step [950/2000], Content Loss: 30.9369, Style Loss: 14.1998
Step [960/2000], Content Loss: 30.9720, Style Loss: 14.0581
Step [970/2000], Content Loss: 31.0021, Style Loss: 13.9193
Step [980/2000], Content Loss: 31.0370, Style Loss: 13.7825
Step [990/2000], Content Loss: 31.0691, Style Loss: 13.6480
Step [1000/2000], Content Loss: 31.0998, Style Loss: 13.5158
Step [1010/2000], Content Loss: 31.1302, Style Loss: 13.3861
Step [1020/2000], Content Loss: 31.1605, Style Loss: 13.2587
Step [1030/2000], Content Loss: 31.1915, Style Loss: 13.1332
Step [1040/2000], Content Loss: 31.2220, Style Loss: 13.0099
Step [1050/2000], Content Loss: 31.2528, Style Loss: 12.8889
Step [1060/2000], Content Loss: 31.2860, Style Loss: 12.7697
Step [1070/2000], Content Loss: 31.3174, Style Loss: 12.6525
Step [1080/2000], Content Loss: 31.3475, Style Loss: 12.5375
Step [1090/2000], Content Loss: 31.3775, Style Loss: 12.4245
Step [1100/2000], Content Loss: 31.4046, Style Loss: 12.3129
Step [1110/2000], Content Loss: 31.4350, Style Loss: 12.2038
Step [1120/2000], Content Loss: 31.4598, Style Loss: 12.0956
Step [1130/2000], Content Loss: 31.4878, Style Loss: 11.9894
Step [1140/2000], Content Loss: 31.5149, Style Loss: 11.8847
Step [1150/2000], Content Loss: 31.5406, Style Loss: 11.7818
Step [1160/2000], Content Loss: 31.5659, Style Loss: 11.6805
Step [1170/2000], Content Loss: 31.5901, Style Loss: 11.5803
Step [1180/2000], Content Loss: 31.6137, Style Loss: 11.4822
Step [1190/2000], Content Loss: 31.6345, Style Loss: 11.3851
Step [1200/2000], Content Loss: 31.6543, Style Loss: 11.2900
Step [1210/2000], Content Loss: 31.6787, Style Loss: 11.1968
Step [1220/2000], Content Loss: 31.7000, Style Loss: 11.1037
Step [1230/2000], Content Loss: 31.7205, Style Loss: 11.0116
Step [1240/2000], Content Loss: 31.7422, Style Loss: 10.9210
Step [1250/2000], Content Loss: 31.7633, Style Loss: 10.8319
Step [1260/2000], Content Loss: 31.7867, Style Loss: 10.7446
Step [1270/2000], Content Loss: 31.8046, Style Loss: 10.6565
Step [1280/2000], Content Loss: 31.8247, Style Loss: 10.5699
Step [1290/2000], Content Loss: 31.8469, Style Loss: 10.4858
Step [1300/2000], Content Loss: 31.8646, Style Loss: 10.4015
Step [1310/2000], Content Loss: 31.8859, Style Loss: 10.3201
Step [1320/2000], Content Loss: 31.9010, Style Loss: 10.2365
Step [1330/2000], Content Loss: 31.9236, Style Loss: 10.1575
Step [1340/2000], Content Loss: 31.9461, Style Loss: 10.0792
Step [1350/2000], Content Loss: 31.9616, Style Loss: 9.9980
Step [1360/2000], Content Loss: 31.9880, Style Loss: 9.9236
Step [1370/2000], Content Loss: 32.0038, Style Loss: 9.8461
Step [1380/2000], Content Loss: 32.0191, Style Loss: 9.7687
Step [1390/2000], Content Loss: 32.0434, Style Loss: 9.6970
Step [1400/2000], Content Loss: 32.0572, Style Loss: 9.6203
Step [1410/2000], Content Loss: 32.0787, Style Loss: 9.5496
Step [1420/2000], Content Loss: 32.0955, Style Loss: 9.4771
Step [1430/2000], Content Loss: 32.1123, Style Loss: 9.4056
Step [1440/2000], Content Loss: 32.1289, Style Loss: 9.3349
Step [1450/2000], Content Loss: 32.1441, Style Loss: 9.2636
Step [1460/2000], Content Loss: 32.1628, Style Loss: 9.1949
Step [1470/2000], Content Loss: 32.1851, Style Loss: 9.1302
Step [1480/2000], Content Loss: 32.1958, Style Loss: 9.0589
Step [1490/2000], Content Loss: 32.2141, Style Loss: 8.9938
Step [1500/2000], Content Loss: 32.2303, Style Loss: 8.9282
Step [1510/2000], Content Loss: 32.2414, Style Loss: 8.8597
Step [1520/2000], Content Loss: 32.2560, Style Loss: 8.7944
Step [1530/2000], Content Loss: 32.2785, Style Loss: 8.7337
Step [1540/2000], Content Loss: 32.2986, Style Loss: 8.6751
Step [1550/2000], Content Loss: 32.2955, Style Loss: 8.6001
Step [1560/2000], Content Loss: 32.3232, Style Loss: 8.5438
Step [1570/2000], Content Loss: 32.3409, Style Loss: 8.4860
Step [1580/2000], Content Loss: 32.3442, Style Loss: 8.4177
Step [1590/2000], Content Loss: 32.3604, Style Loss: 8.3581
Step [1600/2000], Content Loss: 32.3871, Style Loss: 8.3062
Step [1610/2000], Content Loss: 32.3841, Style Loss: 8.2353
Step [1620/2000], Content Loss: 32.4114, Style Loss: 8.1829
Step [1630/2000], Content Loss: 32.4267, Style Loss: 8.1247
Step [1640/2000], Content Loss: 32.4401, Style Loss: 8.0669
Step [1650/2000], Content Loss: 32.4480, Style Loss: 8.0066
Step [1660/2000], Content Loss: 32.4796, Style Loss: 7.9656
Step [1670/2000], Content Loss: 32.4754, Style Loss: 7.8967
Step [1680/2000], Content Loss: 32.4839, Style Loss: 7.8374
Step [1690/2000], Content Loss: 32.5063, Style Loss: 7.7878
Step [1700/2000], Content Loss: 32.5246, Style Loss: 7.7381
Step [1710/2000], Content Loss: 32.5257, Style Loss: 7.6759
Step [1720/2000], Content Loss: 32.5456, Style Loss: 7.6262
Step [1730/2000], Content Loss: 32.5680, Style Loss: 7.5811
Step [1740/2000], Content Loss: 32.5655, Style Loss: 7.5176
Step [1750/2000], Content Loss: 32.5831, Style Loss: 7.4672
Step [1760/2000], Content Loss: 32.6070, Style Loss: 7.4232
Step [1770/2000], Content Loss: 32.6441, Style Loss: 7.4071
Step [1780/2000], Content Loss: 32.6931, Style Loss: 7.4527
Step [1790/2000], Content Loss: 32.7056, Style Loss: 7.4441
Step [1800/2000], Content Loss: 32.6304, Style Loss: 7.2250
Step [1810/2000], Content Loss: 32.6647, Style Loss: 7.1710
Step [1820/2000], Content Loss: 32.6658, Style Loss: 7.1150
Step [1830/2000], Content Loss: 32.6795, Style Loss: 7.0659
Step [1840/2000], Content Loss: 32.6897, Style Loss: 7.0176
Step [1850/2000], Content Loss: 32.7024, Style Loss: 6.9711
Step [1860/2000], Content Loss: 32.7121, Style Loss: 6.9235
Step [1870/2000], Content Loss: 32.7327, Style Loss: 6.8816
Step [1880/2000], Content Loss: 32.7356, Style Loss: 6.8324
Step [1890/2000], Content Loss: 32.7485, Style Loss: 6.7878
Step [1900/2000], Content Loss: 32.7634, Style Loss: 6.7444
Step [1910/2000], Content Loss: 32.7753, Style Loss: 6.6990
Step [1920/2000], Content Loss: 32.7872, Style Loss: 6.6547
Step [1930/2000], Content Loss: 32.8038, Style Loss: 6.6145
Step [1940/2000], Content Loss: 32.8169, Style Loss: 6.5722
Step [1950/2000], Content Loss: 32.8173, Style Loss: 6.5240
Step [1960/2000], Content Loss: 32.8359, Style Loss: 6.4847
Step [1970/2000], Content Loss: 32.8538, Style Loss: 6.4470
Step [1980/2000], Content Loss: 32.8599, Style Loss: 6.4017
Step [1990/2000], Content Loss: 32.8634, Style Loss: 6.3566
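
The denorm transform below undoes the ImageNet normalization applied earlier: un-normalizing x_norm = (x - mean) / std per channel is itself a Normalize with mean' = -mean/std and std' = 1/std, which is where the hard-coded values come from. A minimal sketch of the derivation:

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
print(-mean / std)  # approx. [-2.118, -2.036, -1.804]
print(1 / std)      # approx. [ 4.367,  4.464,  4.444]
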
denorm = transforms.Normalize((-2.12, -2.04, -1.80), (4.37, 4.46, 4.44))
img = target.clone().squeeze()
img = denorm(img).clamp_(0, 1)
plt.figure()
imshow(img, title='Target Image')

Output: (the stylized target image)

2. Generative Adversarial Networks

batch_size=32
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5),
                        std=(0.5, 0.5, 0.5))
])

mnist_data = torchvision.datasets.MNIST("./mnist_data", train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset=mnist_data,
                                         batch_size=batch_size,
                                         shuffle=True)
image_size = 784

hidden_size = 256
# discriminator
D = nn.Sequential(
    nn.Linear(image_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1),
    nn.Sigmoid()
)

latent_size = 64
# Generator
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh()
)

D = D.to(device)
G = G.to(device)

loss_fn = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)
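
A note on the losses used below: with BCELoss, training D against the real/fake labels and training G against the real labels corresponds to

\mathcal{L}_D = -\,\mathbb{E}_{x}\left[\log D(x)\right] - \mathbb{E}_{z}\left[\log\left(1 - D(G(z))\right)\right], \qquad \mathcal{L}_G = -\,\mathbb{E}_{z}\left[\log D(G(z))\right]

i.e. the generator uses the non-saturating loss (maximize log D(G(z))) rather than minimizing log(1 - D(G(z))), which is why g_loss is computed with real_labels in the loop that follows.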

Start training:


def reset_grad():
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()

total_step = len(dataloader)
num_epochs = 200
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(dataloader):
        batch_size = images.size(0)
        images = images.reshape(batch_size, image_size).to(device)
        
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)
        
        outputs = D(images)
        d_loss_real = loss_fn(outputs, real_labels)
        real_score = outputs
        
        # generate fake images
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images.detach())
        d_loss_fake = loss_fn(outputs, fake_labels)
        fake_score = outputs
        
        # optimize the discriminator
        d_loss = d_loss_real + d_loss_fake
        reset_grad()
        d_loss.backward()
        d_optimizer.step()
        
        # optimize the generator
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        g_loss = loss_fn(outputs, real_labels)
        
        reset_grad()
        g_loss.backward()
        g_optimizer.step()
        
        if i % 1000 == 0:
            print("Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}"
                 .format(epoch, num_epochs, i, total_step, d_loss.item(), g_loss.item(), real_score.mean().item(), fake_score.mean().item())

Partial output:

Epoch [0/200], Step [0/1875], d_loss: 0.6669, g_loss: 2.9577, D(x): 0.76, D(G(z)): 0.15
Epoch [0/200], Step [1000/1875], d_loss: 0.1716, g_loss: 3.0008, D(x): 0.93, D(G(z)): 0.09
Epoch [1/200], Step [0/1875], d_loss: 0.1716, g_loss: 4.1396, D(x): 0.93, D(G(z)): 0.02
Epoch [1/200], Step [1000/1875], d_loss: 0.0202, g_loss: 5.1296, D(x): 1.00, D(G(z)): 0.02
Epoch [2/200], Step [0/1875], d_loss: 0.2070, g_loss: 3.7713, D(x): 0.93, D(G(z)): 0.08
Epoch [2/200], Step [1000/1875], d_loss: 0.0829, g_loss: 4.9163, D(x): 0.99, D(G(z)): 0.07
Epoch [3/200], Step [0/1875], d_loss: 0.2986, g_loss: 3.6197, D(x): 0.90, D(G(z)): 0.03
Epoch [3/200], Step [1000/1875], d_loss: 0.4204, g_loss: 2.2956, D(x): 0.90, D(G(z)): 0.14
Epoch [4/200], Step [0/1875], d_loss: 0.4453, g_loss: 5.1677, D(x): 0.80, D(G(z)): 0.02
Epoch [4/200], Step [1000/1875], d_loss: 0.1900, g_loss: 2.7722, D(x): 0.93, D(G(z)): 0.10
Epoch [5/200], Step [0/1875], d_loss: 0.3418, g_loss: 2.4469, D(x): 1.00, D(G(z)): 0.21
Epoch [5/200], Step [1000/1875], d_loss: 0.4460, g_loss: 2.4152, D(x): 0.90, D(G(z)): 0.18
Epoch [6/200], Step [0/1875], d_loss: 0.3142, g_loss: 4.0145, D(x): 0.93, D(G(z)): 0.13
Epoch [6/200], Step [1000/1875], d_loss: 0.5893, g_loss: 3.9873, D(x): 0.97, D(G(z)): 0.31
Epoch [7/200], Step [0/1875], d_loss: 0.3118, g_loss: 3.2590, D(x): 0.88, D(G(z)): 0.10
Epoch [7/200], Step [1000/1875], d_loss: 0.5169, g_loss: 2.8562, D(x): 0.84, D(G(z)): 0.20
Epoch [8/200], Step [0/1875], d_loss: 0.1886, g_loss: 3.0765, D(x): 0.93, D(G(z)): 0.05
Epoch [8/200], Step [1000/1875], d_loss: 0.5987, g_loss: 3.0972, D(x): 0.86, D(G(z)): 0.17
Epoch [9/200], Step [0/1875], d_loss: 0.7312, g_loss: 2.5704, D(x): 0.93, D(G(z)): 0.30

Fake images

z = torch.randn(1, latent_size).to(device)
fake_images = G(z).view(28, 28).data.cpu().numpy()
plt.imshow(fake_images)

<matplotlib.image.AxesImage at 0x7f55b00136d8>
(image: a generated digit)
Real images

plt.imshow(images[0].view(28,28).data.cpu().numpy())
<matplotlib.image.AxesImage at 0x7f55b09e7f60>

(image: a real MNIST digit)

3. DCGAN

Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks

Image download link: https://drive.google.com/drive/folders/0B7EVK8r0v71pbWNEUjJKdDQ3dGc

import torchvision.utils as vutils
# !ls celeba/img_align_celeba/img_align_celeba_png
image_size=64
batch_size=128
dataroot="celeba/img_align_celeba"
num_workers = 2
dataset = torchvision.datasets.ImageFolder(root=dataroot, transform=transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
real_batch=next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis=("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(), (1,2,0)))

<matplotlib.image.AxesImage at 0x7f6db16dafd0>
(image: grid of training images)

We initialize all model weights: convolution weights from a normal distribution with mean=0, std=0.02, and BatchNorm weights with mean=1, std=0.02.

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
nz = 100 # size of the latent vector
ngf = 64 # generator feature map size
ndf = 64 # discriminator feature map size
nc = 3 # color channels

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            # torch.nn.ConvTranspose2d(in_channels, out_channels, 
            # kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1)
            nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)

# Now, we can instantiate the generator and apply the weights_init function. Check out the printed model to see how the generator object is structured.

# Create the generator
netG = Generator().to(device)

# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.02.
netG.apply(weights_init)

# Print the model
print(netG)

Generator(
(main): Sequential(
(0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace)
(3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace)
(6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU(inplace)
(9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(11): ReLU(inplace)
(12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(13): Tanh()
)
)
Discriminator

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)
# Now, as with the generator, we can create the discriminator, apply the weights_init function, and print the model’s structure.

# Create the Discriminator
netD = Discriminator().to(device)


# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.02.
netD.apply(weights_init)

# Print the model
print(netD)
Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)

Start training:

lr = 0.0002
beta1 = 0.5

loss_fn = nn.BCELoss()
fixed_noise = torch.randn(64, nz, 1, 1, device=device)
d_optimizer = torch.optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
g_optimizer = torch.optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
num_epochs = 5
G_losses = []
D_losses = []
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader):
        # train the discriminator: maximize log(D(x)) + log(1 - D(G(z)))
        
        # first, train on real images
        netD.zero_grad()
        
        real_images = data[0].to(device)
        b_size = real_images.size(0)
        label = torch.ones(b_size).to(device)
        output = netD(real_images).view(-1)
        
        
        real_loss = loss_fn(output, label)
        real_loss.backward()
        D_x = output.mean().item()
        
        
        # then, train on the generated fake images
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake_images = netG(noise)
        label.fill_(0)
        output = netD(fake_images.detach()).view(-1)
        fake_loss = loss_fn(output, label)
        fake_loss.backward()
        D_G_z1 = output.mean().item()
        loss_D = real_loss + fake_loss
        d_optimizer.step()
        
        # train the generator
        netG.zero_grad()
        label.fill_(1)
        output = netD(fake_images).view(-1)
        loss_G = loss_fn(output, label)
        loss_G.backward()
        D_G_z2 = output.mean().item()
        g_optimizer.step()
        
        if i % 50 == 0:
            print("[{}/{}] [{}/{}] Loss_D: {:.4f} Loss_G {:.4f} D(x): {:.4f} D(G(z)): {:.4f}/{:.4f}"
                 .format(epoch, num_epochs, i, len(dataloader), loss_D.item(), loss_G.item(), D_x, D_G_z1, D_G_z2))
        
        G_losses.append(loss_G.item())
        D_losses.append(loss_D.item())
        

Partial output:

[0/5] [0/1583] Loss_D: 1.7977 Loss_G 2.8596 D(x): 0.3357 D(G(z)): 0.3494/0.0786
[0/5] [50/1583] Loss_D: 0.4748 Loss_G 30.1861 D(x): 0.7715 D(G(z)): 0.0000/0.0000
[0/5] [100/1583] Loss_D: 0.1432 Loss_G 8.7877 D(x): 0.9865 D(G(z)): 0.1092/0.0016
[0/5] [150/1583] Loss_D: 0.5332 Loss_G 6.9773 D(x): 0.8701 D(G(z)): 0.2674/0.0030
[0/5] [200/1583] Loss_D: 1.5008 Loss_G 8.1102 D(x): 0.4722 D(G(z)): 0.0029/0.0011
[0/5] [250/1583] Loss_D: 0.3476 Loss_G 5.5318 D(x): 0.8942 D(G(z)): 0.1540/0.0132
[0/5] [300/1583] Loss_D: 0.6494 Loss_G 5.9788 D(x): 0.9072 D(G(z)): 0.3348/0.0124
[0/5] [350/1583] Loss_D: 0.8482 Loss_G 5.6696 D(x): 0.8947 D(G(z)): 0.4554/0.0091
[0/5] [400/1583] Loss_D: 0.5689 Loss_G 3.3358 D(x): 0.7856 D(G(z)): 0.1807/0.0647
[0/5] [450/1583] Loss_D: 0.8698 Loss_G 7.5017 D(x): 0.8675 D(G(z)): 0.4281/0.0022
[0/5] [500/1583] Loss_D: 0.3542 Loss_G 3.1888 D(x): 0.8573 D(G(z)): 0.1214/0.0587
[0/5] [550/1583] Loss_D: 0.3387 Loss_G 3.9772 D(x): 0.7958 D(G(z)): 0.0605/0.0351
[0/5] [600/1583] Loss_D: 0.6330 Loss_G 4.3450 D(x): 0.7693 D(G(z)): 0.1875/0.0238
[0/5] [650/1583] Loss_D: 0.6735 Loss_G 4.8144 D(x): 0.6305 D(G(z)): 0.0358/0.0166
[0/5] [700/1583] Loss_D: 0.3484 Loss_G 4.6406 D(x): 0.8652 D(G(z)): 0.1372/0.0182
[0/5] [750/1583] Loss_D: 0.5287 Loss_G 5.8325 D(x): 0.8684 D(G(z)): 0.2675/0.0056
[0/5] [800/1583] Loss_D: 0.6363 Loss_G 3.1169 D(x): 0.6298 D(G(z)): 0.0332/0.0755
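
The G_losses and D_losses lists collected in the loop are not plotted above; a minimal sketch for visualizing them (assuming the lists from the training loop are still in scope):

plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G")
plt.plot(D_losses, label="D")
plt.xlabel("iterations")
plt.ylabel("loss")
plt.legend()
plt.show()
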
with torch.no_grad():
    fake = netG(fixed_noise).detach().cpu()
# fake
real_batch = next(iter(dataloader))

# Plot the real images
plt.figure(figsize=(30,30))
plt.subplot(1,2,1)
plt.axis=("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))

# Plot the fake images from the last epoch
plt.subplot(1,2,2)
plt.axis=("off")
plt.title("Fake Images")
plt.imshow(np.transpose(vutils.make_grid(fake, padding=2, normalize=True), (1,2,0)))
plt.show()

Partial output: (grids of real and fake images)
