文章目录
前言
本项目基于pytorch进行深度学习神经网络搭建、flask框架作为web应用开发,搭建了一个简单的手写数字识别项目,用户可通过在网页上通过拖动鼠标手写数字,通过点击预测进行模型推理,最后显示结果。本文项目链接为:https://github.com/Windxy/torch_mnist_flask
一、环境配置
- Flask == 1.1.2
- torch == 1.2.0
- torchvision == 0.4.0
- tqdm == 4.42.1
二、使用步骤
1.项目克隆
打开cmd,cd到你需要的目录下,然后输入git clone https://github.com/Windxy/torch_mnist_flask
或直接打开网页进行下载
2.下载数据集
相关MNIST数据集在百度网盘链接,提取码:vnyc,然后将其解压在DataSet目录下
3.模型训练
运行train.py训练MNIST数据,使用的SGD,学习率0.0001,100次迭代,loss为交叉熵损失函数,根据自己的要求,可以改变网络结构、优化器、自适应学习率、预处理方式等等。
from model.model import Model
import numpy as np
import torch
from torchvision.datasets import mnist
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
if __name__ == '__main__':
batch_size = 256
train_dataset = mnist.MNIST(root='./DataSet', train=True, transform=ToTensor())
test_dataset = mnist.MNIST(root='./DataSet', train=False, transform=ToTensor())
train_loader = DataLoader(train_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)
model = Model()
model.load_state_dict(torch.load('model/mnist.pth'))
optimizer = SGD(model.parameters(), lr=1e-4)
cross_error = CrossEntropyLoss()
epoch = 100
for _epoch in range(epoch):
for idx, (train_x, train_label) in enumerate(train_loader):
label_np = np.zeros((train_label.shape[0], 10))
optimizer.zero_grad()
out_of_predict = model(train_x.float())
loss = cross_error(out_of_predict, train_label.long())
if idx % 10 == 0:
print('idx: {}, _error: {}'.format(idx, loss))
loss.backward()
optimizer.step()
correct = 0
_sum = 0
for idx, (test_x, test_label) in enumerate(test_loader):
predict_y = model(test_x.float()).detach()
predict_ys = np.argmax(predict_y, axis=-1)
label_np = test_label.numpy()
_ = predict_ys == test_label
correct += np.sum(_.numpy(), axis=-1)
_sum += _.shape[0]
print('accuracy: {:.2f}'.format(correct / _sum))
torch.save(model.state_dict(), 'model/mnist.pth')
4.模型测试
from model.model import Model
import numpy as np
import torch
from torchvision.datasets import mnist
from tqdm import tqdm,trange
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve,roc_curve,f1_score,precision_recall_fscore_support
def normalization(data):
_range = np.max(data) - np.min(data)
return (data - np.min(data)) / _range
if __name__ == '__main__':
batch_size = 1 # 一个一个测
test_dataset = mnist.MNIST(root='./DataSet', train=False, transform=ToTensor())
test_loader = DataLoader(test_dataset, batch_size=batch_size)
model = Model()
model.load_state_dict(torch.load('model/mnist.pth'))
threshold = 0.00001
# 初始化
# 1w个数据,980个0
accurance = 0 #准确率
correct = 0 #正确的数量
nums = 0 #0的个数
y_scores = []
y_true = []
y_pred = []
for (test_x, test_label) in tqdm(test_loader):
'''在这里,取用0作为我们的正例,1—9作为我们的反例'''
predict_y = model(test_x.float()).detach()
predict_ys = predict_y.numpy().squeeze()
predict_ys = normalization(predict_ys)[0] # 是0的概率
y_scores.append(predict_ys)
if test_label.numpy()==0:
nums+=1
correct += 1 if np.argmax(predict_y, axis=-1).item() == 0 else 0
y_true.append(1 if test_label.numpy()[0]==0 else 0) #还是用0来进行PR评估,1 表示为 0,0 表示为非 0
y_pred.append(1 if np.argmax(predict_y, axis=-1).item()==0 else 0)
accurance = correct*1.0/nums #980个0
print("Accuracy:",accurance)
precision, recall, thresholds = precision_recall_curve(y_true, y_scores)
fpr, tpr, tresholds = roc_curve(y_true, y_scores)
F1= f1_score(y_true, y_pred)
print("F1-Score:",F1)
plt.figure(figsize=(50,100))
plt.subplot(1,2,1)
plt.plot(precision, recall)
plt.xlabel(r'Recall') # 坐标
plt.ylabel(r'Precision')
plt.title("figure of PR-Curve")
plt.subplot(1,2,2)
plt.plot(fpr, tpr)
plt.title("figure of ROC")
plt.xlabel(r'False Positive Rate') # 坐标
plt.ylabel(r'True Positive Rate')
plt.show()
Accuracy: 0.9959183673469387
F1-Score: 0.9928789420142421
5.flask部署
运行app.py即可,打开出现的链接,flask网页设置参考的是:
https://blog.csdn.net/qq_38534107/article/details/103565899
总结
1.根据自己的需求,可以在model.py、train.py更改自己想要的网络、数据、训练参数等
2.根据自己的需求,可以在html和js部分更改需要的网页布局和操作
3.本文演示了如何基于pytorch和flask进行模型的训练、评估和网页端手写数字的识别