【深度学习笔记】-代码解读7-模型训练、验证与可视化

转载「深度学习一遍过」必修7:模型训练、验证与可视化_荣仔的博客-CSDN博客

1 Create Dataset

  • 生成训练集和测试集 
  • 生成验证集(下面的代码即完成这一步:遍历 testdata 目录,把每张图片的路径写入 eval.txt,标签统一用 0 占位)
import os
import random#打乱数据用的
 
def CreateEvalData(test_root=os.path.join(".", "testdata"), out_path="eval.txt", label="0"):
    """Write an evaluation list file covering every file under *test_root*.

    Each output line is "<file path><TAB><label>". The images being evaluated
    are unlabeled, so the same placeholder *label* is written for all of them
    (the reader presumably expects a two-column file -- TODO confirm against
    LoadData).

    Args:
        test_root: directory to scan recursively. The default is built with
            os.path.join so the path separator is correct on every OS rather
            than only on Windows.
        out_path: list file to (over)write, UTF-8 encoded.
        label: placeholder label written for every file.

    Returns:
        The list of file paths written, in output order.

    Raises:
        FileNotFoundError: if *test_root* is not an existing directory.
            os.walk ignores a missing root and yields nothing, which would
            otherwise silently produce an empty list file.
    """
    if not os.path.isdir(test_root):
        raise FileNotFoundError(f"test image directory not found: {test_root}")

    data_list = []
    for dirpath, dirnames, filenames in os.walk(test_root):
        # Sorting in place makes os.walk visit sub-directories in a stable
        # order, so eval.txt is reproducible across runs and filesystems.
        dirnames.sort()
        data_list.extend(os.path.join(dirpath, name) for name in sorted(filenames))
    print(data_list)

    with open(out_path, 'w', encoding='UTF-8') as f:
        f.writelines(f"{path}\t{label}\n" for path in data_list)
    return data_list


if __name__ == "__main__":
    CreateEvalData()

2 模型训练

记录训练信息,包括:train loss、test loss、test accuracy

  • 所有层都进行微调
  • 只微调最后一层
  • 从头开始训练,不微调

3 模型验证

 

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from utils import LoadData, write_result
import pandas as pd
 
 
def eval(dataloader, model):
    """Run *model* over *dataloader* and collect the top-1 prediction per sample.

    The name shadows the built-in ``eval``; it is kept for existing callers.

    Args:
        dataloader: iterable of ``(X, y)`` batches; ``y`` is ignored.
        model: a ``torch.nn.Module`` whose output is assumed to be class
            logits of shape (batch, num_classes) -- softmax is taken over dim 1.

    Returns:
        ``(label_list, likelihood_list)``: for every sample, in loader order,
        the predicted class index (int) and its softmax probability (float).
    """
    label_list = []
    likelihood_list = []
    # Send inputs to wherever the model's weights live instead of hard-coding
    # CUDA, so evaluation also works on CPU-only machines.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    model.eval()
    with torch.no_grad():
        for X, _ in dataloader:
            X = X.to(device)
            probs = torch.softmax(model(X), dim=1)
            # Reduce per row: correct for any batch size. A flat argmax over
            # the whole batch array was only right when batch_size == 1.
            likelihood, label = probs.max(dim=1)
            label_list.extend(label.cpu().tolist())
            likelihood_list.extend(likelihood.cpu().tolist())
    return label_list, likelihood_list
 
 
if __name__ == "__main__":

    label_names = ["daisy", "dandelion", "rose", "sunflower", "tulip"]

    # 1. Build the model architecture: ResNet-18 with its fully-connected
    #    head replaced by one sized to the number of classes.
    model = resnet18(pretrained=False)
    num_ftrs = model.fc.in_features    # input width of the original fc layer
    model.fc = nn.Linear(num_ftrs, len(label_names))
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using {device} device")

    # 2. Load the trained weights. map_location lets a checkpoint that was
    #    saved from a GPU be restored on a CPU-only machine.
    model_loc = "./BEST_resnet_epoch_10_acc_80.9.pth"
    model_dict = torch.load(model_loc, map_location=device)
    model.load_state_dict(model_dict)
    model = model.to(device)

    # 3. Load the images listed in eval.txt.
    valid_data = LoadData("eval.txt", train_flag=False)
    # Pinned host memory only helps host-to-GPU copies; skip it on CPU.
    test_dataloader = DataLoader(dataset=valid_data, num_workers=2,
                                 pin_memory=(device == "cuda"), batch_size=1)

    # 4. Predict, map class indices to names, and save one row per image.
    label_list, likelihood_list = eval(test_dataloader, model)
    result_names = [label_names[i] for i in label_list]

    result_df = pd.DataFrame({"label": result_names, "likelihood": likelihood_list})
    print(result_df)
    result_df.to_csv('testdata.csv', encoding='gbk')

4 可视化 

  •  混淆矩阵、召回率、精准率、ROC曲线可视化 
'''
    模型性能度量
'''
from sklearn.metrics import *  # pip install scikit-learn
import matplotlib.pyplot as plt # pip install matplotlib
import numpy as np  # pip install numpy
from numpy import interp
from sklearn.preprocessing import label_binarize
import pandas as pd # pip install pandas
 
# ---------------------------------------------------------------------------
# Load data: the ground-truth labels (true_label) and the model output
# (predict_label / predict_score).
# ---------------------------------------------------------------------------
# test.txt: tab-separated, no header; the "type" column is expected to hold
# integer class indices.
target_loc = "test.txt"
target_data = pd.read_csv(target_loc, sep="\t", names=["loc", "type"])
true_label = target_data["type"].tolist()

# NOTE(review): this expects testdata.csv to hold one score column per class,
# in class-index order (row argmax = predicted class). The CSV written by the
# evaluation script earlier in this post stores "label"/"likelihood" columns
# instead, which cannot be used here as-is -- confirm which file is meant.
predict_loc = "testdata.csv"
predict_data = pd.read_csv(predict_loc)
# A header-less column (reported by pandas as "Unnamed: N") is a DataFrame
# index saved by to_csv(). It is not a class score and would otherwise take
# part in argmax/max, so drop it.
score_columns = [c for c in predict_data.columns if not str(c).startswith("Unnamed")]
predict_data = predict_data[score_columns]

scores = predict_data.to_numpy(dtype=float)
if len(scores) != len(true_label):
    raise ValueError(
        f"{predict_loc} has {len(scores)} rows but {target_loc} has "
        f"{len(true_label)} labels; both must list the same images in the same order"
    )
predict_label = scores.argmax(axis=1)
predict_score = scores.max(axis=1)
print("predict_score = ", predict_score)
 
 
# ---------------------------------------------------------------------------
# Common metrics: accuracy, precision, recall, F1-score
# ---------------------------------------------------------------------------
# One averaging mode shared by precision / recall / F1 so the three numbers
# are comparable. 'micro' is avoided because, for single-label multi-class
# data, micro-averaged precision, recall and F1 all equal plain accuracy.
# Alternatives: 'micro', 'macro', 'weighted'.
average = 'macro'

# Accuracy: fraction of all samples that were predicted correctly.
accuracy = accuracy_score(true_label, predict_label)
print("精度: ", accuracy)

# Precision = TP / (TP + FP): of the samples predicted as a class, the
# fraction that truly belong to it (per class, then averaged).
precision = precision_score(true_label, predict_label, average=average)
print("查准率P: ", precision)

# Recall = TP / (TP + FN): of the samples truly in a class, the fraction that
# were predicted as that class (per class, then averaged).
recall = recall_score(true_label, predict_label, average=average)
print("召回率: ", recall)

# F1-Score: harmonic mean of precision and recall (per class, then averaged).
f1 = f1_score(true_label, predict_label, average=average)
print("F1 Score: ", f1)
 
 
# ---------------------------------------------------------------------------
# Confusion matrix
# ---------------------------------------------------------------------------
label_names = ["daisy", "dandelion", "rose", "sunflower", "tulip"]
# Rows = true class, columns = predicted class, in label_names order.
confusion = confusion_matrix(true_label, predict_label, labels=list(range(len(label_names))))

plt.matshow(confusion, cmap=plt.cm.Oranges)   # Greens, Blues, Oranges, Reds
plt.colorbar()
n_rows, n_cols = confusion.shape
for i in range(n_rows):
    for j in range(n_cols):
        # matshow draws row i at y=i and column j at x=j, and annotate()
        # takes xy=(x, y), so the count for cell [i, j] goes at (j, i).
        plt.annotate(str(confusion[i, j]), xy=(j, i),
                     horizontalalignment='center', verticalalignment='center')
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.xticks(range(len(label_names)), label_names)
plt.yticks(range(len(label_names)), label_names)
plt.title("Confusion Matrix")
plt.show()
 
 
# ---------------------------------------------------------------------------
# ROC curves (multi-class, one-vs-rest)
# ---------------------------------------------------------------------------
n_classes = len(label_names)
# The positives for each one-vs-rest curve must come from the GROUND TRUTH.
# Binarizing predict_label instead would measure the model against its own
# predictions rather than against the true labels.
binarize_true = label_binarize(true_label, classes=list(range(n_classes)))

# Per-class scores: column i is assumed to be the model's score for class i
# (the same layout predict_label was derived from) -- confirm against the CSV.
predict_score = predict_data.to_numpy()

# ROC curve and AUC for each class.
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
    # NOTE: a class absent from true_label has no positives; sklearn warns
    # and that class's TPR is undefined.
    fpr[i], tpr[i], _ = roc_curve(binarize_true[:, i], predict_score[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Macro-average: evaluate every class curve on the union of all FPR points,
# average the interpolated TPRs, then integrate.
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
    mean_tpr += interp(all_fpr, fpr[i], tpr[i])
mean_tpr /= n_classes
fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

# Plot all ROC curves.
lw = 2
plt.figure()
plt.plot(fpr["macro"], tpr["macro"],
         label='macro-average ROC curve (area = {0:0.2f})'.format(roc_auc["macro"]),
         color='navy', linestyle=':', linewidth=4)

for i in range(n_classes):
    plt.plot(fpr[i], tpr[i], lw=lw,
             label='ROC curve of {0} (area = {1:0.2f})'.format(label_names[i], roc_auc[i]))

plt.plot([0, 1], [0, 1], 'k--', lw=lw)   # chance diagonal
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Multi-class receiver operating characteristic ')
plt.legend(loc="lower right")
plt.show()

上一篇:【报错记录】解决SpringBoot使用knife4j无法引入@EnableSwagger2WebMvc


下一篇:sklearn 估计器(estimator)接口