基于 PyTorch 1.8.0 + Win10 + RTX 3070 的 MNIST 网络构建与训练

直接上代码

先给出完整代码(注意:代码中并没有把模型和数据移到 GPU 上,因此默认在 CPU 上运行):

import torch
import torchvision
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# 参考: https://blog.csdn.net/sxf1061700625/article/details/105870851

class Mnist_Net(nn.Module):
    '''
    Small CNN classifier for MNIST digits.

    Layers: conv(1->10, 5x5) -> max-pool(2) -> ReLU -> conv(10->20, 5x5)
    -> Dropout2d -> max-pool(2) -> ReLU -> flatten -> fc(320->50) -> ReLU
    -> dropout -> fc(50->10) -> log-softmax over the 10 classes.

    The 320 input features of ``fc1`` equal 20 channels * 4 * 4, which is what
    two (5x5 conv, 2x2 pool) stages leave of a 28x28 image, so inputs are
    expected to be shaped (N, 1, 28, 28).
    '''
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        # Channel-wise dropout on the second conv's feature maps (reduces overfitting).
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        '''
        Return per-class log-probabilities of shape (N, 10) for the batch ``x``.
        '''
        # conv -> pool -> ReLU
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        # conv -> dropout -> pool -> ReLU
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # Flatten per sample. Unlike view(-1, 320), a wrongly sized input
        # fails in fc1 instead of silently being regrouped across samples.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        # Functional dropout must be told the train/eval mode explicitly.
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # Normalise over the class dimension explicitly; the implicit-dim
        # form of log_softmax is deprecated and emits a warning.
        return F.log_softmax(x, dim=1)

def training_net(epoch,network,train_loader,optimizer,train_losses, train_counter,log_interval):
    '''
    Train ``network`` for one epoch over ``train_loader``.

    Every ``log_interval`` batches the current loss is printed, appended to
    ``train_losses`` (with the matching example count in ``train_counter``),
    and the model/optimizer state dicts are saved to ./model.pth and
    ./optimizer.pth, overwriting the previous checkpoint.

    :param epoch: 1-based epoch number; used in the log line and to offset the
        example count stored in ``train_counter``.
    :param network: model to train; switched to training mode here. Batches
        are moved to the device of its first parameter.
    :param train_loader: DataLoader yielding ``(data, target)`` batches.
    :param optimizer: optimizer over ``network``'s parameters.
    :param train_losses: list the logged loss values are appended to.
    :param train_counter: list the total number of training examples seen at
        each logged loss is appended to (x-axis of the loss curve).
    :param log_interval: log and checkpoint every this many batches.
    :return: None
    '''
    network.train()
    # Follow the model's device so this also works once the network has been
    # moved to a GPU; a parameter-less model leaves the batches untouched.
    first_param = next(network.parameters(), None)
    device = first_param.device if first_param is not None else None
    dataset_size = len(train_loader.dataset)
    for batch_idx, (data, target) in enumerate(train_loader):
        if device is not None:
            data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        # Forward pass
        output = network(data)
        # NLL loss; expects log-probabilities, which is what Mnist_Net returns.
        loss = F.nll_loss(output, target)
        # Back-propagate and update the parameters
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            # Use the loader's configured batch size instead of a hard-coded
            # 64. It is None when a custom batch_sampler is used, so fall back
            # to the current batch's size in that case.
            batch_size = train_loader.batch_size or len(data)
            examples_seen = batch_idx * batch_size
            loss_value = loss.item()
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, examples_seen, dataset_size,
                100. * batch_idx / len(train_loader), loss_value))
            train_losses.append(loss_value)
            train_counter.append(examples_seen + (epoch - 1) * dataset_size)
            # Checkpoint the model and the optimizer state
            torch.save(network.state_dict(), './model.pth')
            torch.save(optimizer.state_dict(), './optimizer.pth')


def testing_net(network, test_loader,test_losses):
    '''
    测试集执行
    :param network:
    :param test_loader:
    :param test_losses:
    :return:
    '''
    network.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            # 首先得到out结果
            output = network(data)
            # 计算LOSS
            test_loss += F.nll_loss(output, target, size_average=False).item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))


def view_dataset_figure(test_loader):
    '''
    Preview the first batch of ``test_loader``.

    Prints the batch's targets and the data tensor's shape, then plots up to
    six of its images (2x3 grid) titled with their ground-truth labels.

    :param test_loader: DataLoader yielding ``(data, target)`` batches. Images
        are read as ``data[i][0]``, i.e. assumed (N, C, H, W) with the first
        channel shown — TODO confirm for non-MNIST data.
    :return: None
    '''
    # Look at what one batch of test data consists of.
    example_data, example_targets = next(iter(test_loader))
    print(example_targets)
    print(example_data.shape)
    # The 2x3 grid holds at most six images; smaller batches show fewer
    # instead of raising IndexError.
    n_show = min(6, len(example_data))
    plt.figure()
    for i in range(n_show):
        plt.subplot(2, 3, i + 1)
        plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
        plt.title("Ground Truth: {}".format(example_targets[i]))
        plt.xticks([])
        plt.yticks([])
    # Lay out once, after every subplot and title exists.
    plt.tight_layout()
    plt.show()


def show_loss_line_figure(train_counter,train_losses,test_counter,test_losses):
    '''
    Plot the loss curves against the number of training examples seen.

    Training loss is drawn as a blue line, test loss as red points, and the
    figure is shown immediately.

    :param train_counter: x values for ``train_losses``.
    :param train_losses: training loss values (line).
    :param test_counter: x values for ``test_losses``.
    :param test_losses: test loss values (scatter).
    :return: None
    '''
    plt.figure()
    plt.plot(train_counter, train_losses, color='blue')
    plt.scatter(test_counter, test_losses, color='red')
    # Legend entries follow the order in which the artists were added above.
    legend_entries = ['Train Loss', 'Test Loss']
    plt.legend(legend_entries, loc='upper right')
    x_title, y_title = 'number of training examples seen', 'negative log likelihood loss'
    plt.xlabel(x_title)
    plt.ylabel(y_title)
    plt.show()


def show_predict_result(network,test_loader):
    '''
    Plot up to six images from the first batch of ``test_loader``, each
    titled with ``network``'s predicted class.

    :param network: trained model. It is put into eval mode here so that
        dropout is disabled for prediction, and it stays in eval mode afterwards.
    :param test_loader: DataLoader yielding ``(data, target)`` batches; the
        test split is what the callers in this file pass.
    :return: None
    '''
    example_data, _ = next(iter(test_loader))
    # Dropout must be off when predicting.
    network.eval()
    first_param = next(network.parameters(), None)
    inputs = example_data if first_param is None else example_data.to(first_param.device)
    with torch.no_grad():
        # Compute every predicted class once, not per subplot.
        predictions = network(inputs).argmax(dim=1).cpu()
    # The 2x3 grid holds at most six images; smaller batches show fewer.
    n_show = min(6, len(example_data))
    plt.figure()
    for i in range(n_show):
        plt.subplot(2, 3, i + 1)
        plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
        plt.title("Prediction: {}".format(predictions[i].item()))
        plt.xticks([])
        plt.yticks([])
    # Lay out once, after every subplot and title exists.
    plt.tight_layout()
    plt.show()


def execute_through_new():
    '''
    Train a freshly initialised Mnist_Net from scratch.

    Builds the MNIST train/test loaders (downloading into ./data/ if needed),
    previews a test batch, evaluates the untrained network once, then runs
    ``n_epochs`` rounds of training + evaluation and finally plots the loss
    curves and sample predictions.

    :return: None
    '''
    # ---- hyper-parameters ----
    n_epochs = 3
    batch_size_train, batch_size_test = 64, 1000
    learning_rate, momentum = 0.01, 0.5
    log_interval = 10
    torch.manual_seed(1)

    # ---- data: both splits share one (stateless) preprocessing pipeline ----
    preprocess = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_loader_obj = DataLoader(
        torchvision.datasets.MNIST('./data/', train=True, download=True, transform=preprocess),
        batch_size=batch_size_train, shuffle=True)
    test_loader_obj = DataLoader(
        torchvision.datasets.MNIST('./data/', train=False, download=True, transform=preprocess),
        batch_size=batch_size_test, shuffle=True)
    view_dataset_figure(test_loader_obj)

    # ---- model and optimiser ----
    network_obj = Mnist_Net()
    optimizer_obj = optim.SGD(network_obj.parameters(), lr=learning_rate, momentum=momentum)

    # ---- bookkeeping for the loss plot ----
    train_losses_obj, train_counter_obj, test_losses_obj = [], [], []
    examples_per_epoch = len(train_loader_obj.dataset)
    # One test evaluation before training plus one after each epoch.
    test_counter_obj = [k * examples_per_epoch for k in range(n_epochs + 1)]

    testing_net(network_obj, test_loader_obj, test_losses_obj)
    for epoch in range(1, n_epochs + 1):
        training_net(epoch, network_obj, train_loader_obj, optimizer_obj,
                     train_losses_obj, train_counter_obj, log_interval)
        testing_net(network_obj, test_loader_obj, test_losses_obj)

    # ---- visualisation ----
    show_loss_line_figure(train_counter_obj, train_losses_obj, test_counter_obj, test_losses_obj)
    show_predict_result(network_obj, test_loader_obj)


def execute_through_checkpoint():
    '''
    Resume training from the checkpoint files model.pth / optimizer.pth.

    Rebuilds the MNIST train/test loaders (downloading into ./data/ if
    needed), previews a test batch, restores the network and optimizer state,
    evaluates once, then runs ``n_epochs`` more rounds of training +
    evaluation and finally plots the loss curves and sample predictions.

    :return: None
    '''
    n_epochs = 30
    batch_size_train = 64
    batch_size_test = 1000
    learning_rate = 0.01
    momentum = 0.5
    log_interval = 10
    random_seed = 1
    torch.manual_seed(random_seed)
    # Load data; both splits use the same preprocessing.
    train_loader_obj = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST('./data/', train=True, download=True,
                                   transform=torchvision.transforms.Compose([
                                       torchvision.transforms.ToTensor(),
                                       torchvision.transforms.Normalize(
                                           (0.1307,), (0.3081,))
                                   ])),
        batch_size=batch_size_train, shuffle=True)
    test_loader_obj = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST('./data/', train=False, download=True,
                                   transform=torchvision.transforms.Compose([
                                       torchvision.transforms.ToTensor(),
                                       torchvision.transforms.Normalize(
                                           (0.1307,), (0.3081,))
                                   ])),
        batch_size=batch_size_test, shuffle=True)
    # Preview the data
    view_dataset_figure(test_loader_obj)
    # Build the network and optimizer (the network lives on the CPU here)
    continued_network_obj = Mnist_Net()
    continued_optimizer_obj = optim.SGD(continued_network_obj.parameters(), lr=learning_rate, momentum=momentum)
    # Restore the checkpoint. map_location='cpu' lets a checkpoint saved from
    # a GPU run load on a machine without that GPU, matching the CPU network
    # built above.
    network_state_dict = torch.load('model.pth', map_location='cpu')
    continued_network_obj.load_state_dict(network_state_dict)
    optimizer_state_dict = torch.load('optimizer.pth', map_location='cpu')
    continued_optimizer_obj.load_state_dict(optimizer_state_dict)

    train_losses_obj = []
    train_counter_obj = []
    test_losses_obj = []
    # One test evaluation before resuming plus one after each epoch.
    test_counter_obj = [i * len(train_loader_obj.dataset) for i in range(n_epochs + 1)]
    # Evaluate the restored model once. Sample output from the author's run:
    # Test set: Avg. loss: 0.0347, Accuracy: 9887/10000 (99%)
    testing_net(continued_network_obj, test_loader_obj, test_losses_obj)
    for epoch in range(1, n_epochs + 1):
        # Train one epoch, then evaluate on the test set
        training_net(epoch, continued_network_obj, train_loader_obj, continued_optimizer_obj,
                     train_losses_obj, train_counter_obj, log_interval)
        testing_net(continued_network_obj, test_loader_obj, test_losses_obj)
    # Plot the loss curves
    show_loss_line_figure(train_counter_obj, train_losses_obj, test_counter_obj, test_losses_obj)
    # Visualise some predictions
    show_predict_result(continued_network_obj, test_loader_obj)

# Entry point
if __name__ == '__main__':
    # True: resume training from the saved model.pth / optimizer.pth checkpoint.
    # False: train a brand-new model from scratch.
    resume_from_checkpoint = True
    if resume_from_checkpoint:
        execute_through_checkpoint()
    else:
        execute_through_new()

算法流程

口诀:2【加载数据、定义模型】+ 2【训练(4 个环节)、测试(2 个环节)】
[图:基于 PyTorch 1.8.0 + Win10 + RTX 3070 的 MNIST 网络构建与训练流程]
这是主体流程:先加载数据、定义模型,然后是训练和测试两大步骤。其中训练包括 4 个环节:网络前向运行、计算 LOSS、反向传播、优化器更新参数;测试包括 2 个环节:网络前向运行、计算 LOSS(同时统计准确率)。

讨论网络模型定义

网络由 5 个层模块构成:两个卷积层、一个 Dropout2d 层(降低过拟合)和两个线性(全连接)层。此外,forward 中还使用了最大池化和 ReLU 激活,并在 fc1 之后以函数形式调用了一次 F.dropout,最后返回 F.log_softmax(x)。需要注意的是,Mnist_Net 继承自 nn.Module。
关于nn.Module的详细介绍会在后面的章节展开。

主要参考资料

https://blog.csdn.net/sxf1061700625/article/details/105870851
上一篇:UI5-学习篇-4-SCP-SAP WEB IDE登录


下一篇:rtx3060ti相当于什么水平?rtx3060ti显卡评测