使用Flask部署自己的pytorch模型(猫品种分类模型)

使用Flask部署自己的pytorch模型(猫品种分类模型)

全部代码开源在YTY666ZSY/Flask_Cat_7classify — yty666zsy/Flask_Cat_7classify (github.com)

一、数据集准备

来自大佬的文章调用pytorch的resnet,训练出准确率高达96%的猫12类分类模型。 - 知乎 (zhihu.com),在其基础上进行修改的。

在视觉中国中使用爬虫来进行猫咪品种的爬取,爬取后的图片需要自己去检查有没有错误,清洗图片数据。

如下代码所示,需要修改file_path,指定保存地址,修改base_url,例如"buoumao"为布偶猫的拼音,如果想搜索其他品种的猫,直接更改拼音就可以。

import asyncio
import re  

import aiohttp
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.edge.options import Options


def ks_download_uel(image_urls):
    async def download_images(url_list):
        async with aiohttp.ClientSession() as session:
            global k
            for url in url_list:
                try:
                    async with session.get("https:" + url) as response:  # "https:" + url 进行网址拼接
                        response.raise_for_status()
                        file_path = fr"F:\project\猫12分类\data\孟买猫\{k}.jpg"  # 指定保存地址
                        with open(file_path, 'wb') as file:
                            while True:
                                chunk = await response.content.read(8192)
                                if not chunk:
                                    break
                                file.write(chunk)
                    print(f"已经完成 {k} 张")
                except Exception as e:
                    print(f"下载第 {k} 张出现错误 :{str(e)}")
                k += 1  # 为下一张做标记

    # 创建事件循环对象
    loop = asyncio.get_event_loop()
    # 调用异步函数
    loop.run_until_complete(download_images(image_urls))


if __name__ == '__main__':
    base_url = 'https://www.vcg.com/creative-image/mengmaimao/?page={page}'  # "buoumao"为布偶猫的拼音,如果想搜索其他品种的猫,直接更改拼音就可以
    edge_options = Options()
    edge_options.add_argument("--headless")  # 不显示浏览器敞口, 加快爬取速度。
    edge_options.add_argument("--no-sandbox")  # 防止启动失败
    driver = webdriver.Edge(options=edge_options)

    k = 1  # 为保存的每一种图片做标记
    for page in range(1, 5):  # 每一页150张,十页就够了。
        if page == 1:  # 目的是就打开一个网特,减少内存开销
            driver.get(base_url.format(page=page))  # 开始访问第page页
        elements = driver.find_elements(By.XPATH,
                                        '//*[@id="imageContent"]/section[1]')  # 将返回 //*[@id="imageContent"]/section/div 下的所有子标签元素
        urls_ls = []  # 所要的图片下载地址。
        for element in elements:
            html_source = element.get_attribute('outerHTML')
            urls_ls = re.findall('data-src="(.*?)"', str(html_source))  # 这里用了正则匹配,可以加快执行速度

        #  下面给大家推荐一个异步快速下载图片的方法, 建议这时候尽量多提供一下cpu和内存为程序
        ks_download_uel(urls_ls)

        driver.execute_script(f"window.open('{base_url.format(page=page)}', '_self')")  # 在当前窗口打开新网页,减少内存使用
    driver.quit()  # 在所有网页访问完成后退出 WebDriver

爬取后的图片保存在指定的位置

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

二、训练模型

如下面代码所示,我们使用res50的预训练模型,但是需要注意的是最后的线性层model.fc需要修改为自己需要的分类种类,train_data_path修改为自己data所在位置,需要说的是我们并不需要主动去划分测试集和训练集,我们只需要进行数据分类,在代码中会自动分类。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, WeightedRandomSampler, Subset
from tqdm import tqdm
from collections import Counter
import matplotlib.pyplot as plt  # 用于绘制图形

class Trainer:
    def __init__(self, model, device, train_loader, valid_loader, lr=0.0001):
        self.model = model.to(device)  # 将模型转移到设备上
        self.device = device
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.optimizer = optim.SGD(self.model.parameters(), lr=lr, momentum=0.9)

    def train_one_epoch(self, epoch):
        self.model.train()
        correct_predictions = 0
        total_samples = 0
        epoch_loss = 0  # 记录当前轮次的损失
        
        for inputs, targets in tqdm(self.train_loader):
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            predictions = outputs.argmax(dim=1)
            correct_predictions += (predictions == targets).sum().item()
            total_samples += targets.size(0)
            loss = nn.CrossEntropyLoss()(outputs, targets)
            epoch_loss += loss.item()  # 累加损失
            loss.backward()
            self.optimizer.step()
        
        accuracy = 100. * correct_predictions / total_samples
        average_loss = epoch_loss / len(self.train_loader)  # 计算平均损失
        print(f"Epoch {epoch}: Train Accuracy: {accuracy:.2f}%, Loss: {average_loss:.4f}")
        return average_loss  # 返回当前轮次的平均损失

    def validate(self):
        self.model.eval()
        correct_predictions = 0
        total_samples = 0
        total_loss = 0.0
        
        with torch.no_grad():
            for inputs, targets in self.valid_loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = self.model(inputs)
                loss = nn.CrossEntropyLoss()(outputs, targets)
                total_loss += loss.item()
                predictions = outputs.argmax(dim=1)
                correct_predictions += (predictions == targets).sum().item()
                total_samples += targets.size(0)
        
        accuracy = 100. * correct_predictions / total_samples
        average_loss = total_loss / len(self.valid_loader)  # 计算平均损失
        print(f"Validation Accuracy: {accuracy:.2f}%, Loss: {average_loss:.4f}")
        return accuracy, average_loss  # 返回准确率和损失


def create_data_loaders(train_root, batch_size):
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomResizedCrop(244, scale=(0.6, 1.0), ratio=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4848, 0.4435, 0.4023], std=[0.2744, 0.2688, 0.2757])
    ])

    dataset = torchvision.datasets.ImageFolder(root=train_root, transform=transform)
    class_counts = Counter(dataset.targets)
    weights = [1.0 / class_counts[idx] for idx in dataset.targets]
    sampler = WeightedRandomSampler(weights, len(dataset), replacement=True)
    
    train_subset = Subset(dataset, list(sampler))
    valid_indices = [idx for idx in range(len(dataset)) if idx not in list(sampler)]
    valid_subset = Subset(dataset, valid_indices)

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(valid_subset, batch_size=batch_size, shuffle=False)

    return train_loader, valid_loader


if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    learning_rate = 0.0001
    epochs = 30
    batch_size = 32
    train_data_path = r"E:\github\my_github\Flask_cat_7classify\data"

    train_loader, valid_loader = create_data_loaders(train_data_path, batch_size)

    model = torchvision.models.resnet50(weights='ResNet50_Weights.DEFAULT')
    model.fc = nn.Linear(2048, 7)  # 调整输出层以适应7个类别

    trainer = Trainer(model, device, train_loader, valid_loader)

    best_accuracy = 0.0
    best_model_state = None
    train_losses = []  # 记录每轮训练损失
    valid_losses = []  # 记录每轮验证损失

    for epoch in range(1, epochs + 1):
        train_loss = trainer.train_one_epoch(epoch)
        train_losses.append(train_loss)  # 存储训练损失
        accuracy, valid_loss = trainer.validate()
        valid_losses.append(valid_loss)  # 存储验证损失

        if accuracy > best_accuracy:
            best_accuracy = accuracy
            best_model_state = model.state_dict()

    print(f"Best Validation Accuracy: {best_accuracy:.2f}%")
    torch.save(best_model_state, fr"E:\github\my_github\Flask_cat_7classify\best_model_train{best_accuracy:.2f}.pth")

    # 绘制损失变化图
    plt.figure(figsize=(10, 5))
    plt.plot(range(1, epochs + 1), train_losses, label='Train Loss', color='blue')
    plt.plot(range(1, epochs + 1), valid_losses, label='Validation Loss', color='orange')
    plt.title('Loss Change Over Epochs')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid()
    plt.savefig('loss_plot.png')  # 保存图形为 PNG 文件
    plt.show()  # 显示图形

三、使用flask部署模型

如下面代码所示,我们需要修改模型导入的位置,然后修改线性层类别,特别需要注意的一点是在定义类别上categories,我们需要按照文件夹的顺序来填写。

from flask import Flask, request, jsonify, send_from_directory  
import torchvision
import torch
import torchvision.transforms as transforms
from PIL import Image
import io

# 初始化 Flask 应用
app = Flask(__name__)

# 设置设备为 GPU,如果不可用则使用 CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 定义图像转换管道
transform = transforms.Compose([
    transforms.Resize(256),  # 将图像调整为 256x256
    transforms.CenterCrop(244),  # 裁剪中心的 244x244 区域
    transforms.ToTensor(),  # 将图像转换为 PyTorch 张量
    transforms.Normalize(mean=[0.4848, 0.4435, 0.4023], std=[0.2744, 0.2688, 0.2757])  # 归一化
])

# 加载 ResNet50 模型并修改为 7 个分类
model = torchvision.models.resnet50(weights=None)
model.fc = torch.nn.Linear(2048, 7)  # 设置输出层为 7 个类
model.load_state_dict(torch.load(r"E:\github\my_github\Flask_cat_7classify\best_model_train92.81.pth", map_location=device))
model.to(device)  # 将模型移动到指定设备
model.eval()  # 设置模型为评估模式

# 定义类别
categories = ['俄罗斯蓝猫', '孟买猫', '布偶猫', '暹罗猫', '波斯猫', '缅因猫', '英国短毛猫']

# 预测图像类别的函数
def predict_image(image_bytes):
    image = Image.open(io.BytesIO(image_bytes))  # 从字节加载图像
    image = transform(image).unsqueeze(0).to(device)  # 转换并添加批次维度
    with torch.no_grad():  # 禁用梯度计算
        output = model(image)  # 获取模型预测
        _, predicted = torch.max(output, 1)  # 获取预测的类别索引
        return predicted.item()  # 返回预测的类别索引

# 定义预测路由
@app.route('/predict', methods=['POST'])
def predict():
    if 'file' not in request.files:
        return jsonify({'error': '没有文件部分'}), 400  # 如果没有文件部分,返回错误

    file = request.files['file']  # 获取上传的文件

    if file.filename == '':
        return jsonify({'error': '没有选择文件'}), 400  # 如果没有选择文件,返回错误

    if file:
        img_bytes = file.read()  # 读取文件字节
        try:
            prediction_index = predict_image(img_bytes)  # 获取预测索引
            prediction_label = categories[prediction_index]  # 将索引映射到标签
            return jsonify({'prediction': prediction_label})  # 返回预测结果
        except Exception as e:
            return jsonify({'error': str(e)}), 500  # 处理预测过程中出现的错误

# 用于服务静态文件的路由
@app.route('/<path:path>')
def send_static(path):
    return send_from_directory('templates', path)  # 从 templates 目录提供静态文件

# 运行 Flask 应用
if __name__ == '__main__':
    app.run(debug=True, host='0.0.0.0', port=5000)

另外在templates中定义了简单的静态html文件,如下图所示外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传最后运行app.py,然后在浏览器打开http://127.0.0.1:5000/index.html就可以成功实现
在这里插入图片描述

这样我们就实现了使用falsk部署深度学习模型的简易实现,在大佬文章中使用pyqt和gradio,感兴趣的可以实现一下。

上一篇:Android 重新定义一个广播修改系统时间,避免系统时间混乱


下一篇:磁盘的物理组成(Linux网络服务器 15)