在获得模型推理结果后,我们需要对其进行评估,以衡量分类的准确性。
eval.ipynb
from settings import LABEL_NAMES
from utils import load_obj
from datasets import load_dataset
from settings import data_files, output_dirs
import os
os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890'
os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'
ds = load_dataset("fancyzhx/ag_news")
def eval(raw_dataset, vllm_predict):
right = 0 # 预测正确的数量
multi_label = 0 # 预测多标签的数量
for data, output in zip(raw_dataset, vllm_predict):
true_label = LABEL_NAMES[data['label']]
output_text = output.outputs[0].text
pred_label = output_text.split("label")[-1]
tmp_pred = []
for label in LABEL_NAMES:
if label in pred_label:
tmp_pred.append(label)
if len(tmp_pred) > 1:
multi_label += 1
if " ".join(tmp_pred) == true_label:
right += 1
return right, multi_label
我们分别对 basic 和 reason 预测结果进行了评估。
basic 预测结果的评估 :
dataset = load_dataset(
'csv',
data_files=data_files[0],
split='train'
)
output = load_obj(output_dirs[0])
eval(dataset, output)
输出结果:
(5845, 143)
加了reason 预测结果评估:
dataset = load_dataset(
'csv',
data_files=data_files[1],
split='train'
)
output = load_obj(output_dirs[1])
eval(dataset, output)
输出结果:
(6293, 14)
评估结果如下:
-
basic: 直接分类准确率为 77%(5845/7600),误分类为多标签的样本有 143 个。
-
reason: 在输出原因后分类准确率提高至 83%(6293/7600),多标签误分类样本减少至 14 个。
误分类多标签: 这是单分类问题,大模型应该只输出一个类别,但是它输出了多个类别;
可以发现,让大模型输出reason,其分类准确率提升了5%。
在误分类多标签的数量也有所下降。原先误分类多标签有143条数据,使用reason后,多标签误分类的数量降低到了14条。
这些结果表明,让模型输出 reason的过程,确实能够有效提升分类准确性,并减少误分类多个标签的情况。