使用 PyTorch 实现简化版 GoogLeNet 进行 MNIST 图像分类

介绍

        本文将介绍如何使用 PyTorch 实现一个简化版的 GoogLeNet 网络来进行 MNIST 图像分类。GoogLeNet 是 Google 提出的深度卷积神经网络(CNN),其通过 Inception 模块大大提高了计算效率并提升了分类性能。我们将实现一个简化版的 GoogLeNet,用于处理 MNIST 数据集,该数据集由手写数字图片组成,适合用于小规模的图像分类任务。

项目结构

        我们将代码分为两个部分:

  • 训练脚本 train.py:包括数据加载、模型构建、训练过程等。
  • 测试脚本 test.py:用于加载训练好的模型并在测试集上评估性能。

项目依赖

        在开始之前,我们需要安装以下 Python 库:

  • torch:PyTorch 深度学习框架
  • torchvision:提供数据加载和图像变换功能
  • matplotlib:用于可视化

        可以通过以下命令安装所有依赖:

pip install -r requirements.txt

  requirements.txt 文件内容如下:

torch==2.0.1
torchvision==0.15.0
matplotlib==3.6.3

数据预处理与加载

1. 数据加载和预处理

        在训练模型之前,我们需要对 MNIST 数据集进行预处理。以下是数据加载和预处理的代码:

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def get_data_loader(batch_size=64, train=True):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))  # 正规化到 [-1, 1] 范围
    ])

    dataset = datasets.MNIST(root='./data', train=train, download=True, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=4)

        这里,我们使用了 transforms.Compose 来进行数据预处理,包括将图像转换为 Tensor 格式,并进行归一化处理。


训练部分:train.py

2. 模型定义:简化版 GoogLeNet

        为了在 MNIST 数据集上训练,我们构建了一个简化版的 GoogLeNet,包含三个 Inception 模块和一个全连接层。每个 Inception 模块由一个卷积层和一个最大池化层组成。简化的 GoogLeNet 模型如下:

import torch.nn as nn

class SimpleGoogLeNet(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleGoogLeNet, self).__init__()

        # 第一个 Inception 模块
        self.inception1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )

        # 第二个 Inception 模块
        self.inception2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )

        # 第三个 Inception 模块
        self.inception3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )

        # 分类器:全连接层 + Dropout 层
        self.fc = nn.Sequential(
            nn.Linear(128 * 3 * 3, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, num_classes)
        )

    def forward(self, x):
        x = self.inception1(x)
        x = self.inception2(x)
        x = self.inception3(x)
        x = x.view(x.size(0), -1)  # 展平输入
        x = self.fc(x)
        return x

3. 训练函数

        训练过程包括前向传播、反向传播和优化。我们将使用 Adam 优化器和 交叉熵损失 来训练模型:

import torch.optim as optim
from tqdm import tqdm

def train_epoch(model, device, train_loader, criterion, optimizer):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    with tqdm(train_loader, desc="Training", unit="batch", ncols=100) as pbar:
        for inputs, labels in pbar:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            pbar.set_postfix(loss=running_loss / (total // 64), accuracy=100 * correct / total)

    return running_loss / len(train_loader), 100 * correct / total

4. 训练脚本:train.py

        训练脚本将包括模型的定义、数据加载、训练过程等:

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from model import SimpleGoogLeNet  # 假设模型在 model.py 文件中

def train_model():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SimpleGoogLeNet().to(device)
    
    train_loader = get_data_loader(batch_size=64, train=True)
    
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    epochs = 10
    for epoch in range(epochs):
        loss, accuracy = train_epoch(model, device, train_loader, criterion, optimizer)
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss:.4f}, Accuracy: {accuracy:.2f}%")
        
    torch.save(model.state_dict(), "simplified_googlenet.pth")  # 保存模型

if __name__ == '__main__':
    train_model()


测试部分:test.py

5. 测试函数

        在测试阶段,我们将使用 torch.no_grad() 禁用梯度计算,提高推理速度,并计算模型在测试集上的准确率:

def test_model(model, device, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')

6. 测试脚本:test.py

        测试脚本将加载训练好的模型并对测试集进行评估:

import torch
from model import SimpleGoogLeNet  # 假设模型在 model.py 文件中
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def get_test_loader(batch_size=64):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))  # 正规化到 [-1, 1] 范围
    ])
    test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    return DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SimpleGoogLeNet().to(device)
    model.load_state_dict(torch.load("simplified_googlenet.pth"))  # 加载训练好的模型
    
    test_loader = get_test_loader(batch_size=64)
    test_model(model, device, test_loader)

if __name__ == '__main__':
    main()

总结

        本文介绍了如何使用 PyTorch 实现简化版的 GoogLeNet,并将代码分为训练(train.py)和测试(test.py)部分。在训练脚本中,我们定义了一个简化版的 GoogLeNet,训练模型并保存训练结果。而在测试脚本中,我们加载训练好的模型并在测试集上进行评估。

        通过这些步骤,我们能够快速地实现一个高效的图像分类模型,并在 MNIST 数据集上进行训练与测试。

完整项目
GitHub - qxd-ljy/GoogLeNet-PyTorch: 使用PyTorch实现GooLeNet进行MINST图像分类使用PyTorch实现GooLeNet进行MINST图像分类. Contribute to qxd-ljy/GoogLeNet-PyTorch development by creating an account on GitHub.https://github.com/qxd-ljy/GoogLeNet-PyTorchGitHub - qxd-ljy/GoogLeNet-PyTorch: 使用PyTorch实现GooLeNet进行MINST图像分类使用PyTorch实现GooLeNet进行MINST图像分类. Contribute to qxd-ljy/GoogLeNet-PyTorch development by creating an account on GitHub.https://github.com/qxd-ljy/GoogLeNet-PyTorch

        希望这篇博客对你有所帮助,欢迎继续探索 PyTorch 和深度学习的更多应用!

上一篇:c++临时对象详解


下一篇:无人机飞手入门指南