卷积神经网络与传统神经网络的训练模块基本一致,但网络模型差异较大。
一 读取数据
# 导包
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets,transforms
import matplotlib.pyplot as plt
import numpy as np
# ---- Load the data ----

# Hyperparameters
input_size = 28    # each MNIST image is 28 x 28 pixels
num_classes = 10   # number of label classes
num_epochs = 3     # number of passes over the training set
batch_size = 64    # images per mini-batch: 64 images go through one training step together

# Training split; download=True fetches the data into ./data when it is missing.
train_dataset = datasets.MNIST(root='./data', train=True,
                               transform=transforms.ToTensor(), download=True)

# Test split. No download flag here — presumably relies on the files fetched
# by the training-split call above (verify if the dataset class changes).
test_dataset = datasets.MNIST(root='./data', train=False,
                              transform=transforms.ToTensor())

# Mini-batch loaders: the training data is reshuffled every epoch,
# the test data is read in its stored order.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size)
前面是导包部分。
数据源仍然是 MNIST,分为训练集与测试集,并使用 DataLoader 构建 batch 数据。
二 搭建卷积神经网络模型
#网络 模型
class CNN(nn.Module):
    """Two-stage convolutional classifier for single-channel 28x28 images.

    Per-sample shapes through the network:
        input                                   (1, 28, 28)
        conv1: Conv 5x5 (pad 2) -> ReLU -> MaxPool 2   -> (16, 14, 14)
        conv2: Conv 5x5 (pad 2) -> ReLU -> MaxPool 2   -> (32, 7, 7)
        flatten                                 -> 32 * 7 * 7 = 1568
        out:   Linear                           -> num_classes scores

    Args:
        num_classes: number of output scores. Defaults to 10 (the MNIST
            digits), so ``CNN()`` builds exactly the original network.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Sequential(      # input: (1, 28, 28)
            nn.Conv2d(
                in_channels=1,           # single channel: grayscale images
                out_channels=16,         # number of feature maps to produce
                kernel_size=5,           # 5x5 convolution kernel
                stride=1,                # step size of the kernel
                padding=2,               # (28 - 5 + 2*2) / 1 + 1 = 28: spatial size preserved
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2), # 2x2 pooling halves H and W -> (16, 14, 14)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 5, 1, 2),  # 16 input channels -> 32 feature maps, 14x14 preserved
            nn.ReLU(),
            nn.MaxPool2d(2),             # -> (32, 7, 7)
        )
        # Fully connected classification head over the flattened feature maps.
        self.out = nn.Linear(32 * 7 * 7, num_classes)

    def forward(self, x):
        """Return raw (unnormalized) class scores of shape (batch, num_classes).

        ``x`` is expected to be (batch, 1, 28, 28); other spatial sizes would
        not yield the 32*7*7 features the linear layer was built for.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        # Flatten (batch, 32, 7, 7) -> (batch, 32*7*7) so it can feed the linear layer.
        x = x.view(x.size(0), -1)
        return self.out(x)
#准确率
def accuracy(predictions, labels):
    """Count correct top-1 predictions in one batch.

    Args:
        predictions: score tensor; the predicted class is the arg-max along
            dim 1 (presumably shape (batch, num_classes) — confirm at call site).
        labels: ground-truth class indices; reshaped to match the predictions.

    Returns:
        A ``(correct, total)`` pair: ``correct`` is a 0-dim tensor with the
        number of matches, ``total`` is ``len(labels)``.
    """
    scores = predictions.detach()
    predicted = scores.max(dim=1).indices
    target = labels.detach().view_as(predicted)
    n_correct = (predicted == target).sum()
    return n_correct, len(labels)
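这里的主要参数(conv1 与 conv2 中的参数)需要结合老师的视频去理解。卷积输出尺寸为 (输入尺寸 - 卷积核大小 + 2×padding) / stride + 1,即 (28 - 5 + 2×2) / 1 + 1 = 28,所以卷积不改变图像大小,而每次 2×2 池化把尺寸减半:28 → 14 → 7,最终得到 32×7×7 的特征图。另外,前向传播中调用 out 之前的 view 变形操作,是把 (batch_size, 32, 7, 7) 展平为 (batch_size, 32×7×7),要结合上一节"从矩阵降维到全连接层"的内容来理解。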
三 训练网络模型
# Instantiate the model.
net = CNN()
# Loss function: cross-entropy over the raw scores returned by CNN.forward.
criterion = nn.CrossEntropyLoss()
# Optimizer: Adam with learning rate 0.001.
optimizer = optim.Adam(net.parameters(), lr=0.001)

# Training loop.
for epoch in range(num_epochs):
    # (correct, total) pairs for each training batch of this epoch —
    # same idea as keeping a running list of losses.
    train_right = []
    for batch_idx, (data, target) in enumerate(train_loader):
        net.train()  # switch back to training mode (evaluation below calls eval())
        output = net(data)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()  # step() is called on the optimizer, not on the loss
        right = accuracy(output, target)
        train_right.append(right)

        if batch_idx % 100 == 0:  # every 100 batches, evaluate on the test set
            net.eval()
            val_right = []
            # Evaluation needs no gradients; disabling autograd avoids building
            # a graph for every test batch, saving memory and time.
            with torch.no_grad():
                for (data, target) in test_loader:
                    output = net(data)
                    right = accuracy(output, target)
                    val_right.append(right)

            # Aggregate (correct, total) over all batches seen so far.
            train_rate = (sum(tup[0] for tup in train_right), sum(tup[1] for tup in train_right))
            val_rate = (sum(tup[0] for tup in val_right), sum(tup[1] for tup in val_right))

            # .item() converts 0-dim tensors to Python numbers and, unlike
            # .numpy(), also works for tensors on a GPU.
            print('当前epoch:{}[{}/{}({:.0f}%)]\t 损失:{:.6f}\t 训练集准确率:{:.2f}%\t 测试集准确率:{:.2f}%'.format(
                epoch, batch_idx * batch_size, len(train_loader.dataset),
                100.0 * batch_idx / len(train_loader),
                loss.item(),
                100.0 * train_rate[0].item() / train_rate[1],
                100.0 * val_rate[0].item() / val_rate[1]
            ))
每训练 100 个 batch,在验证集(这里用测试集)上评估一次,输出结果如下:
当前epoch:0[0/60000(0%)] 损失:2.300367 训练集准确率:12.50% 测试集准确率:10.14%
当前epoch:0[6400/60000(11%)] 损失:0.280655 训练集准确率:74.13% 测试集准确率:92.48%
当前epoch:0[12800/60000(21%)] 损失:0.166317 训练集准确率:83.76% 测试集准确率:95.60%
当前epoch:0[19200/60000(32%)] 损失:0.105674 训练集准确率:87.71% 测试集准确率:95.65%
当前epoch:0[25600/60000(43%)] 损失:0.094606 训练集准确率:89.83% 测试集准确率:97.22%
当前epoch:0[32000/60000(53%)] 损失:0.065384 训练集准确率:91.26% 测试集准确率:97.60%
当前epoch:0[38400/60000(64%)] 损失:0.049964 训练集准确率:92.25% 测试集准确率:97.51%
当前epoch:0[44800/60000(75%)] 损失:0.035163 训练集准确率:93.01% 测试集准确率:97.83%
当前epoch:0[51200/60000(85%)] 损失:0.055695 训练集准确率:93.56% 测试集准确率:98.14%
当前epoch:0[57600/60000(96%)] 损失:0.014890 训练集准确率:94.03% 测试集准确率:97.77%
当前epoch:1[0/60000(0%)] 损失:0.081240 训练集准确率:93.75% 测试集准确率:98.20%
当前epoch:1[6400/60000(11%)] 损失:0.049458 训练集准确率:98.04% 测试集准确率:98.24%
当前epoch:1[12800/60000(21%)] 损失:0.026402 训练集准确率:98.12% 测试集准确率:98.18%
当前epoch:1[19200/60000(32%)] 损失:0.056982 训练集准确率:98.11% 测试集准确率:98.49%
当前epoch:1[25600/60000(43%)] 损失:0.098775 训练集准确率:98.13% 测试集准确率:98.63%
当前epoch:1[32000/60000(53%)] 损失:0.119748 训练集准确率:98.15% 测试集准确率:98.26%
当前epoch:1[38400/60000(64%)] 损失:0.024341 训练集准确率:98.18% 测试集准确率:98.49%
当前epoch:1[44800/60000(75%)] 损失:0.017717 训练集准确率:98.20% 测试集准确率:97.95%
当前epoch:1[51200/60000(85%)] 损失:0.084650 训练集准确率:98.20% 测试集准确率:98.45%
当前epoch:1[57600/60000(96%)] 损失:0.014650 训练集准确率:98.18% 测试集准确率:98.68%
当前epoch:2[0/60000(0%)] 损失:0.089021 训练集准确率:96.88% 测试集准确率:98.54%
当前epoch:2[6400/60000(11%)] 损失:0.048318 训练集准确率:98.72% 测试集准确率:98.68%
当前epoch:2[12800/60000(21%)] 损失:0.051317 训练集准确率:98.71% 测试集准确率:98.62%
当前epoch:2[19200/60000(32%)] 损失:0.033962 训练集准确率:98.67% 测试集准确率:98.53%
当前epoch:2[25600/60000(43%)] 损失:0.025890 训练集准确率:98.72% 测试集准确率:98.79%
当前epoch:2[32000/60000(53%)] 损失:0.007487 训练集准确率:98.72% 测试集准确率:98.57%
当前epoch:2[38400/60000(64%)] 损失:0.015440 训练集准确率:98.74% 测试集准确率:98.81%
当前epoch:2[44800/60000(75%)] 损失:0.006676 训练集准确率:98.73% 测试集准确率:98.84%
当前epoch:2[51200/60000(85%)] 损失:0.034487 训练集准确率:98.72% 测试集准确率:98.85%
当前epoch:2[57600/60000(96%)] 损失:0.042631 训练集准确率:98.73% 测试集准确率:98.73%