说明:
- 一直没搞明白PR曲线老师说的“对预测结果进行排序,排在最前面的是模型认为最可能为正例的样本,排在最后的是模型认为最不可能为正例的样本,按此顺序逐个把样本作为正例进行预测,每次都可以得出查准率和查全率”,通过查询和看别人示例终于明白了,这其实是一个预测过程,将每个实例预测结果作为阈值对所有样本进行预测计算查准率和查全率。
- PR曲线是由模型的查准率和查全率为坐标轴形成的曲线,查准率P为纵坐标 查全率R为横坐标
- P 查准率:在二分类问题中所有预测为正向的样本中真正为正向样本的比例 P=TP/(TP+FP)
- R 查全率:在二分类问题中所有正向样本中被正确预测的样本的比例 R=TP/(TP+FN)
- TP:真正例 FP:假正例 TN:真反例 FN:假反例
- 用于生成PR曲线的数据为随机数据,不能代表真正模型预测评估,只用于完成PR曲线
导入包
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
产生测试数据
# 产生两组 0到1之间的随机数
# 演示数据1
rand_1 = list(np.random.random(20))
# 生成标签和预测概率数据
test_1 = []
for i in range(20):
label = "P" if i < 10 else "N"
test_1.append({"value": rand_1[i], "label": label})
# 对概率进行排序
rand_1.sort(reverse=True)
# 演示数据2
rand_2 = list(np.random.random(20))
# 生成标签和预测概率数据
test_2 = []
for i in range(20):
label = "P" if i < 10 else "N"
test_2.append({"value": rand_2[i], "label": label})
# 对概率进行排序
rand_2.sort(reverse=True)
计算PR值
# 计算PR值
# values 模型预测的所有样本为正的概率列表
# data 模型预测的数据与样本自身正确标签
def get_pr(values=[], datas=[]):
pr = []
for value in values:
counts = {"TP": 0, "FP": 0, "TN": 0, "FN": 0}
for data in datas:
predict_label = "P" if data["value"] >= value else "N"
if predict_label == "P" and data["label"] == "P":
counts["TP"] += 1
elif predict_label == "P" and data["label"] == "N":
counts["FP"] += 1
elif predict_label == "N" and data["label"] == "N":
counts["TN"] += 1
elif predict_label == "N" and data["label"] == "P":
counts["FN"] += 1
# 计算查准率
p = round(counts["TP"]/(counts["TP"]+counts["FP"]), 2)
# 计算查全率
r = round(counts["TP"]/(counts["TP"]+counts["FN"]), 2)
pr.append({"p": p, "r": r})
return pr
组合数据 用于绘制图表
pr_1 = get_pr(rand_1, test_1)
pr_2 = get_pr(rand_2, test_2)
# 生成展示数据
data_show = []
for pr in pr_1:
data_show.append({'p': pr['p'], 'r': pr['r'], 'model': 'model_1'})
for pr in pr_2:
data_show.append({'p': pr['p'], 'r': pr['r'], 'model': 'model_2'})
for pr in range(20):
value = (1.0/20)*pr
data_show.append({'p': value, 'r': value, 'model': 'BEP'})
data_show = pd.DataFrame(data_show)
绘制图表
sns.relplot(x="r", y="p", ci=None, hue='model', kind="line", data=data_show);