vllm 部署GLM4模型进行 Zero-Shot 文本分类实验,让大模型给出分类原因,准确率可提高6%-评估

在获得模型推理结果后,我们需要对其进行评估,以衡量分类的准确性。

eval.ipynb

from settings import LABEL_NAMES  
from utils import load_obj  
  
from datasets import load_dataset  
from settings import data_files, output_dirs  
  
import os  
os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890'  
os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'  
  
ds = load_dataset("fancyzhx/ag_news")  

def eval(raw_dataset, vllm_predict):  
  
    right = 0 # 预测正确的数量  
    multi_label = 0 # 预测多标签的数量  
  
    for data, output in zip(raw_dataset, vllm_predict):  
        true_label = LABEL_NAMES[data['label']]  
  
        output_text = output.outputs[0].text  
        pred_label = output_text.split("label")[-1]  
  
        tmp_pred = []  
        for label in LABEL_NAMES:  
            if label in pred_label:  
                tmp_pred.append(label)  
  
        if len(tmp_pred) > 1:  
            multi_label += 1  
  
        if " ".join(tmp_pred) == true_label:  
            right += 1  
  
    return right, multi_label  

我们分别对 basic 和 reason 预测结果进行了评估。

basic 预测结果的评估 :

dataset = load_dataset(  
    'csv',   
    data_files=data_files[0],   
    split='train'  
    )  
output = load_obj(output_dirs[0])  
  
eval(dataset, output)  

输出结果:

(5845, 143)  

加了reason 预测结果评估:

dataset = load_dataset(  
    'csv',   
    data_files=data_files[1],   
    split='train'  
    )  
output = load_obj(output_dirs[1])  
  
eval(dataset, output)  

输出结果:

(6293, 14)  

评估结果如下:

  • basic: 直接分类准确率为 77%(5845/7600),误分类为多标签的样本有 143 个。

  • reason: 在输出原因后分类准确率提高至 83%(6293/7600),多标签误分类样本减少至 14 个。

误分类多标签: 这是单分类问题,大模型应该只输出一个类别,但是它输出了多个类别;

可以发现,让大模型输出reason,其分类准确率提升了5%。

在误分类多标签的数量也有所下降。原先误分类多标签有143条数据,使用reason后,多标签误分类的数量降低到了14条。

这些结果表明,让模型输出 reason的过程,确实能够有效提升分类准确性,并减少误分类多个标签的情况。

上一篇:QD1-P13 HTML 表单标签(form)


下一篇:嵌入式工业显示器在食品生产行业的应用