ResNet残差网络Pytorch实现——对花的种类进行训练
上一篇:【课程1 - 第二周作业】 ✌✌✌✌ 【目录】 ✌✌✌✌ 下一篇:【课程1 - 第三周作业】
大学生一枚,最近在学习神经网络,写这篇文章只是记录自己的学习历程,本文参考了Github上fengdu78老师的文章进行学习
✌ 使用ResNet进行对花的种类进行训练
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms,datasets
from tqdm import tqdm
# 加载设备,使用cpu还是显卡
device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('using {} device'.format(device))
# 图片处理,对应train和验证集
data_transform={
'train':transforms.Compose([transforms.RandomResizedCrop(224), # 将图片裁剪为224*224
transforms.RandomHorizontalFlip(), # 将图片随机反转
transforms.ToTensor(), # 转化为ToTensor
# 进行标准化
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
'val':transforms.Compose([transforms.Resize(256), # 调整图片大小
transforms.CenterCrop(224), # 中心裁剪224*224
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
}
# os.getcwd()获得当前位置的绝对路径
# os.path.join()将两个路径进行拼接
img_path=os.path.join(os.getcwd(),'flower_data')
# 加载训练数据,同时对需要训练的图片进行处理
train_dataset=datasets.ImageFolder(root=os.path.join(img_path,'train'),
transform=data_transform['train'])
# 加载验证集
val_dataset=datasets.ImageFolder(root=os.path.join(img_path,'val'),
transform=data_transform['val'])
# 定义每个训练批次的数据数量,对应每次训练16张图片
batch_size=16
# 训练数据的加载器
# 根据批次大小进行将数据进行分批
# 一般来说训练数据需要打乱,而验证集不需要
train_loader=torch.utils.data.DataLoader(train_dataset,
batch_size,
shuffle=True)
# 验证数据的加载器
val_loader=torch.utils.data.DataLoader(val_dataset,
batch_size,
shuffle=False)
# 训练和验证的数据大小
train_num=len(train_dataset)
val_num=len(val_dataset)
print('using {} images for training, {} images for validation.'.format(train_num,val_num))
# train_dataset.class_to_idx会返回'A':0,'B':1,'C':2,即每个类别对应的数值映射
flower_list=train_dataset.class_to_idx
# 将其逆置,为了预测时根据预测结果的分类找出对应的字符真实分类,如果不做,最终预测只知道是0,1,2这种,不知道花的真实类别
# data_dataset加载图片会根据图片所在的文件夹确定其分类
cla_dict=dict((value,key) for key,value in flower_list.items())
# 将字典转成json串
json_str=json.dumps(cla_dict,indent=4)
# 将json串写入到文件
with open('class_indices.json','w') as json_file:
json_file.write(json_str)
# 创建网络
net=resnet34()
# 模型参数路径
model_weight_path='./resnet34-pre.pth'
# 加载模型参数
net.load_state_dict(torch.load(model_weight_path,map_location=device))
# 取出全连接层进行替换,因为模型中默认是1000分类,而本题中是5分类,
# 所以要取出全连接层获得全连接层的输入层加上现在的新输出5分类,构建新的全连接层
in_channel=net.fc.in_features
net.fc=nn.Linear(in_channel,5)
net.to(device)
# 交叉熵损失函数
loss_function=nn.CrossEntropyLoss()
# 获得可以更新梯度的参数
params=[p for p in net.parameters() if p.requires_grad]
# 定义优化器
optimizer=optim.Adam(params,lr=0.0001)
# 训练轮数,将所有的数据训练epochs次
epochs=3
# 最好的准确度
best_acc=0
# 训练的模型参数路径
save_path='./resNet34.pth'
# 训练集的批数
train_steps=len(train_loader)
for epoch in range(epochs):
# 开启训练模式
net.train()
running_loss=0
train_bar=tqdm(train_loader)
for data in train_bar:
# 获得每个训练批的数据和标签
images,labels=data
optimizer.zero_grad()
output=net(images.to(device))
loss=loss_function(output,labels.to(device))
loss.backward()
optimizer.step()
running_loss+=loss.item()
train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
epochs,
loss)
# 开启验证模式
net.eval()
acc=0
# 不需要进行梯度下降求导
with torch.no_grad():
val_bar=tqdm(val_loader)
for data in val_bar:
images,labels=data
output=net(images.to(device))
# touch.max()返回指定维度的最大值和该值所在的索引
y_pred=torch.max(output,dim=1)[1]
# 计算预测正确的个数
acc+=torch.eq(y_pred,labels.to(device)).sum().item()
val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,
epochs)
val_accurate=acc/val_num
print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
(epoch + 1, running_loss / train_steps, val_accurate))
# 如果当前的准确率>最高的准确率就将其取代,保存当前训练的参数
if val_accurate>best_acc:
best_acc=val_accurate
torch.save(net.state_dict(),save_path)