Straight to the code
Here is the complete code first:
import torch
import torchvision
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# Reference: https://blog.csdn.net/sxf1061700625/article/details/105870851?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522162486393316780265489114%2522%252C%2522scm%2522%253A%252220140713.130102334..%2522%257D&request_id=162486393316780265489114&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2~all~sobaiduend~default-1-105870851.first_rank_v2_pc_rank_v29_1&utm_term=pytorch++mnist&spm=1018.2226.3001.4187
class Mnist_Net(nn.Module):
    '''
    Define the network.
    '''
def __init__(self):
super(Mnist_Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
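        # 320 = 20 channels * 4 * 4 pixels: 28 -> conv1(k=5) -> 24 -> pool -> 12 -> conv2(k=5) -> 8 -> pool -> 4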
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
def forward(self, x):
        # Conv1 -> max pool -> ReLU
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        # Conv2 -> dropout -> max pool -> ReLU
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # Flatten to (batch_size, 320)
        x = x.view(-1, 320)
        # Fully connected layer with ReLU
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # Return log-probabilities over the 10 classes
        return F.log_softmax(x, dim=1)
def training_net(epoch, network, train_loader, optimizer, train_losses, train_counter, log_interval):
    '''
    Train the network for one epoch.
    :param epoch: current epoch number (used for logging)
    :param network: the model
    :param train_loader: DataLoader for the training set
    :param optimizer: the optimizer
    :param train_losses: list collecting the training losses
    :param train_counter: list recording how many training examples have been seen
    :param log_interval: print and record every `log_interval` batches
    :return:
    '''
network.train()
for batch_idx, (data, target) in enumerate(train_loader):
optimizer.zero_grad()
        # Forward a batch of images through the network
output = network(data)
        # Compute the negative log-likelihood loss
loss = F.nll_loss(output, target)
        # Backpropagate the loss
loss.backward()
        # Update the parameters
optimizer.step()
if batch_idx % log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
            train_losses.append(loss.item())
            # Record the number of training examples seen so far (64 == batch_size_train)
            train_counter.append((batch_idx * 64) + ((epoch - 1) * len(train_loader.dataset)))
            # Save the model checkpoint
            torch.save(network.state_dict(), './model.pth')
            # Save the optimizer state
            torch.save(optimizer.state_dict(), './optimizer.pth')
def testing_net(network, test_loader, test_losses):
    '''
    Evaluate the network on the test set.
    :param network: the model
    :param test_loader: DataLoader for the test set
    :param test_losses: list collecting the test losses
    :return:
    '''
network.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
                # Forward pass
                output = network(data)
                # Accumulate the summed negative log-likelihood loss
                test_loss += F.nll_loss(output, target, reduction='sum').item()
                # The prediction is the class with the highest log-probability
                pred = output.data.max(1, keepdim=True)[1]
                correct += pred.eq(target.data.view_as(pred)).sum()
test_loss /= len(test_loader.dataset)
test_losses.append(test_loss)
print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))
def view_dataset_figure(test_loader):
    '''
    Show a few samples (images and labels) from the test set.
    :param test_loader: DataLoader for the test set
    :return:
    '''
    # Let's look at what one batch of test data consists of.
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
print(example_targets)
print(example_data.shape)
fig = plt.figure()
for i in range(6):
plt.subplot(2, 3, i + 1)
plt.tight_layout()
plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
plt.title("Ground Truth: {}".format(example_targets[i]))
plt.xticks([])
plt.yticks([])
plt.show()
def show_loss_line_figure(train_counter, train_losses, test_counter, test_losses):
    '''
    Plot the loss curves.
    :param train_counter: x values (examples seen) for the training losses
    :param train_losses: training losses
    :param test_counter: x values (examples seen) for the test losses
    :param test_losses: test losses
    :return:
    '''
fig = plt.figure()
plt.plot(train_counter, train_losses, color='blue')
plt.scatter(test_counter, test_losses, color='red')
plt.legend(['Train Loss', 'Test Loss'], loc='upper right')
plt.xlabel('number of training examples seen')
plt.ylabel('negative log likelihood loss')
plt.show()
def show_predict_result(network, test_loader):
    '''
    Visualize predictions; currently uses a batch from the test set.
    :param network: the (trained) model
    :param test_loader: DataLoader for the test set
    :return:
    '''
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
with torch.no_grad():
output = network(example_data)
fig = plt.figure()
for i in range(6):
plt.subplot(2, 3, i + 1)
plt.tight_layout()
plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
plt.title("Prediction: {}".format(
output.data.max(1, keepdim=True)[1][i].item()))
plt.xticks([])
plt.yticks([])
plt.show()
def execute_through_new():
    '''
    Train a brand-new model from scratch.
    :return:
    '''
n_epochs = 3
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 10
random_seed = 1
torch.manual_seed(random_seed)
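    # (0.1307, 0.3081) below are the commonly used mean/std of the MNIST training set, used for normalization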
train_loader_obj = torch.utils.data.DataLoader(
torchvision.datasets.MNIST('./data/', train=True, download=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(
(0.1307,), (0.3081,))
])),
batch_size=batch_size_train, shuffle=True)
test_loader_obj = torch.utils.data.DataLoader(
torchvision.datasets.MNIST('./data/', train=False, download=True, transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.1307,), (0.3081,))])
), batch_size=batch_size_test, shuffle=True
)
view_dataset_figure(test_loader_obj)
network_obj = Mnist_Net()
optimizer_obj = optim.SGD(network_obj.parameters(), lr=learning_rate,momentum=momentum)
train_losses_obj = []
train_counter_obj = []
test_losses_obj = []
test_counter_obj = [i * len(train_loader_obj.dataset) for i in range(n_epochs + 1)]
testing_net(network_obj, test_loader_obj, test_losses_obj)
for epoch in range(1, n_epochs + 1):
        # Train for one epoch
training_net(epoch, network_obj, train_loader_obj, optimizer_obj, train_losses_obj, train_counter_obj,log_interval)
        # Evaluate on the test set
testing_net(network_obj, test_loader_obj, test_losses_obj)
    # Plot the loss curves
show_loss_line_figure(train_counter_obj,train_losses_obj,test_counter_obj,test_losses_obj)
    # Visualize predictions
show_predict_result(network_obj,test_loader_obj)
def execute_through_checkpoint():
    '''
    Resume training from a saved checkpoint.
    :return:
    '''
n_epochs = 30
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 10
random_seed = 1
torch.manual_seed(random_seed)
    # Load the training data
train_loader_obj = torch.utils.data.DataLoader(
torchvision.datasets.MNIST('./data/', train=True, download=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(
(0.1307,), (0.3081,))
])),batch_size=batch_size_train, shuffle=True)
    # Load the test data
test_loader_obj = torch.utils.data.DataLoader(
torchvision.datasets.MNIST('./data/', train=False, download=True, transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.1307,), (0.3081,))])
), batch_size=batch_size_test, shuffle=True
)
    # Show a few samples from the test set
view_dataset_figure(test_loader_obj)
    # Create the network
continued_network_obj = Mnist_Net()
    # Create the optimizer
continued_optimizer_obj = optim.SGD(continued_network_obj.parameters(), lr=learning_rate,momentum=momentum)
    # Reload the checkpoint (model and optimizer states saved by training_net)
network_state_dict = torch.load('model.pth')
continued_network_obj.load_state_dict(network_state_dict)
optimizer_state_dict = torch.load('optimizer.pth')
continued_optimizer_obj.load_state_dict(optimizer_state_dict)
train_losses_obj = []
train_counter_obj = []
test_losses_obj = []
test_counter_obj = [i * len(train_loader_obj.dataset) for i in range(n_epochs + 1)]
    # Evaluate the test set once before resuming, e.g. Test set: Avg. loss: 0.0347, Accuracy: 9887/10000 (99%)
testing_net(continued_network_obj, test_loader_obj, test_losses_obj)
for epoch in range(1, n_epochs + 1):
        # Train one epoch, then evaluate on the test set after each epoch
training_net(epoch, continued_network_obj, train_loader_obj, continued_optimizer_obj, train_losses_obj, train_counter_obj,log_interval)
testing_net(continued_network_obj, test_loader_obj, test_losses_obj)
    # Plot the loss curves
show_loss_line_figure(train_counter_obj,train_losses_obj,test_counter_obj,test_losses_obj)
    # Visualize predictions
show_predict_result(continued_network_obj,test_loader_obj)
### Main entry point
if __name__ == '__main__':
    # Case 1: train a brand-new model
    # execute_through_new()
    # Case 2: resume training from a saved checkpoint
    execute_through_checkpoint()
Algorithm flow
Mnemonic: 2 [load the data, define the model] + 2 [training has 4 steps, testing has 2]
This is the main flow: the two big stages are training and testing. Training consists of four steps: run the network (forward pass), compute the loss, backpropagate, and update with the optimizer. Testing consists of two steps: run the network and compute the loss.
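To make the mnemonic concrete, here is a minimal sketch of that skeleton; the names train_one_epoch and evaluate are illustrative and simply mirror training_net and testing_net in the full script above.
import torch
import torch.nn.functional as F

def train_one_epoch(network, train_loader, optimizer):
    network.train()
    for data, target in train_loader:
        optimizer.zero_grad()
        output = network(data)             # 1. run the network (forward pass)
        loss = F.nll_loss(output, target)  # 2. compute the loss
        loss.backward()                    # 3. backpropagate
        optimizer.step()                   # 4. optimizer update

def evaluate(network, test_loader):
    network.eval()
    total_loss = 0.0
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data)                                            # 1. run the network
            total_loss += F.nll_loss(output, target, reduction='sum').item()  # 2. compute the loss
    return total_loss / len(test_loader.dataset)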
Discussion of the network model definition
The model has five layers: two convolutional layers, one dropout layer (to reduce overfitting), and two linear layers; the forward pass ends by returning F.log_softmax(x, dim=1). Note that Mnist_Net inherits from nn.Module.
A detailed introduction to nn.Module will be given in a later chapter.
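One detail worth spelling out: the 320 passed to fc1 is the flattened feature size. A 28x28 MNIST image becomes 24x24 after conv1 (kernel 5), 12x12 after the first max pool, 8x8 after conv2, and 4x4 after the second max pool, with 20 channels, so 20 * 4 * 4 = 320. A quick sketch to verify the shapes (the dummy tensor is purely illustrative):
import torch
import torch.nn.functional as F

net = Mnist_Net()                      # the class defined above
dummy = torch.zeros(1, 1, 28, 28)      # one fake MNIST image: (batch, channels, height, width)
x = F.max_pool2d(net.conv1(dummy), 2)  # -> (1, 10, 12, 12)
x = F.max_pool2d(net.conv2(x), 2)      # -> (1, 20, 4, 4); dropout omitted, it does not change the shape
print(x.shape)                         # torch.Size([1, 20, 4, 4])
print(x.view(-1, 320).shape)           # torch.Size([1, 320])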
Main references
https://blog.csdn.net/sxf1061700625/article/details/105870851?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522162486393316780265489114%2522%252C%2522scm%2522%253A%252220140713.130102334..%2522%257D&request_id=162486393316780265489114&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2~all~sobaiduend~default-1-105870851.first_rank_v2_pc_rank_v29_1&utm_term=pytorch++mnist&spm=1018.2226.3001.4187