这里只是一个走全程的代码,重在体验,如果想学习深度学习建议看
大体步骤是:
1.先处理数据,分训练集和测试集
2.构建模型
3.优化模型参数
4.保存模型
5.加载模型,测试
训练代码
# -*- coding: utf-8 -*-
# day day study day day up
# create by a Man
import torch
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn
from torch import optim
from tqdm import tqdm
from torch import save
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
#1.数据加载
my_transforms=transforms.Compose(
[transforms.ToTensor(),#将图片变成张量
transforms.Normalize(mean=(0.1307,),std=(0.3081,)) #标准化处理
])
batch_size=64##batch_size每批要加载多少样本(默认值:1)
mnist_train = MNIST(root="../MNIST_data",
train=True, #训练集
download=True, #如果下了就设置为False
transform=my_transforms)
mnist_test=MNIST(root="../MNIST_test",
train=False,#测试集
download=True,#如果下了就设置为False
transform=my_transforms)
#2.构建模型
# 全连接层
class MnistModel(nn.Module):
def __init__(self):
super(MnistModel, self).__init__()
self.fc1 = nn.Linear(1*28*28, 100) # 最终为什么是 10,因为手写数字识别最终是 10分类的,分类任务中有多少,就分几类。 0-9
self.relu = nn.ReLU()
self.fc2=nn.Linear(100,10)
def forward(self, image):
image_viwed = image.view(-1, 1*28*28) # 此处需要拍平
out = self.fc1(image_viwed)
fc1_out = self.relu(out)
out2=self.fc2(fc1_out)
return out2
#3.优化模型参数
def train(train_dataloader, model, loss_function, optimizer):
'''
训练
:param train_dataloader:
:param model:
:param loss_function:
:param optimizer:
:return:
'''
model.train()#必写,需要训练一次,不然报错
for (images, labels) in tqdm(train_dataloader,total=len(train_dataloader)):
images, labels = images.to(device), labels.to(device)
#梯度置零
optimizer.zero_grad()
#前向传播
output=model(images)
#通过结果计算损失
loss=loss_function(output,labels)
#反向传播
loss.backward()
#单次优化,优化器更新
optimizer.step()
def test(test_dataloader, model, loss_function):
'''
测试
:param test_dataloader:
:param model:
:param loss_function:
:return:
'''
model.eval()
size = len(test_dataloader.dataset)
num_batches = len(test_dataloader)
model.eval()
test_loss, correct = 0, 0
with torch.no_grad():
for X, y in test_dataloader:
X, y = X.to(device), y.to(device)
pred = model(X) #
test_loss += loss_function(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
print(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
model=MnistModel().to(device)#实例化模型并到gpu上运算
optimizer = optim.Adam(model.parameters())#优化器选择
loss_function= nn.CrossEntropyLoss()#选择交叉熵损失
train_dataloader = DataLoader(mnist_train, batch_size=batch_size,shuffle=True)#shuffle ( bool , optional ) – 设置为True在每个 epoch 重新洗牌数据(默认值:)False。
test_dataloader=DataLoader(mnist_test,batch_size=batch_size,shuffle=True)
epochs=10
for t in range(epochs):
print(f"Epoch {t+1}\n-------------------------------")
train(train_dataloader, model, loss_function, optimizer)
test(test_dataloader, model, loss_function)
print("Done!")
#4.保存
save(model.state_dict(),"minist.pkl")#存模型
save(optimizer.state_dict(),"optimizer.pkl")#存优化器
拿自己图片测试的代码
from torchvision import transforms
from torch import nn
import torch
from PIL import Image
class MnistModel(nn.Module):
def __init__(self):
super(MnistModel, self).__init__()#父类初始化命令行
self.fc1 = nn.Linear(1*28*28, 100) # 28是像素,最终为什么是 10,因为手写数字识别最终是10(0,1,2,3,4,5,6,7,8,9)分类的,分类任务中有多少,就分几类
self.relu = nn.ReLU()#激活函数
self.fc2 = nn.Linear(100, 10)#线性层
def forward(self, image):
image_viwed = image.view(-1, 1*28*28) #重点: 此处需要拍平
out_1 = self.fc1(image_viwed)
fc1 = self.relu(out_1)#激活函数
out_2 = self.fc2(fc1)
return out_2
model = MnistModel()
model.load_state_dict(torch.load("D:\pych\pytorch_test\minist.pkl"))#路径尽量写绝对路径吧
image = Image.open(r'D:\Desktop\5.jpg')#需要测试的图片路径
# print(image)<PIL.JpegImagePlugin.JpegImageFile image mode=RGB(三通道) size=224x205 at 0x19583197430>
my_transforms = transforms.Compose(
[
transforms.Grayscale(1),#通道变为1,因为图片之前是RGB三通道的
transforms.ToTensor(),#变张量
transforms.Normalize(mean=(0.1307, ), std=(0.3081, ))#z-score 标准化,参数type为元组
]
)
image = my_transforms(image)
with torch.no_grad():#禁止梯度计算,因为测试效果不需要
pred = model(image)
result=pred.max(dim=1).indices
print(result)#