BERT结合Faiss实现问答检索(少于50行代码)

本文主要介绍一个框架nlp-basictasks
nlp-basictasks是利用PyTorch深度学习框架所构建一个简单的库,旨在快速搭建模型完成一些基础的NLP任务,如分类、匹配、序列标注、语义相似度计算等。

下面利用该框架实现问答检索

所谓检索,就是将本地所有待检索的每一个句子编码成一个向量,将这些向量存储起来,称之为索引index。然后针对用户的问题query,将其编码为一个对应的向量query_vector,计算query_vector和索引中每一个vector的相似度,得到相似度最高的几个vector,最后返回这些vector对应的问题。
此外还需要一个变量query2id,
query2id指的是每一个问题和对应的id之间的映射,这个id就是索引中这个问题对应的向量的id,因为当我们计算完相似度最高的几个向量后,还要根据向量在索引中的id从query2id中找出对应的问题

数据集介绍

实验所用数据集是常用的中文自然语言推理数据集lcqmc,来源http://icrc.hitsz.edu.cn/Article/show/171.html

all_sentences=[]
data_folder='lcqmc/lcqmc_test.tsv'#我们只取少量的test.tsv实验
with open(data_folder) as f:
    lines=f.readlines()
    for line in lines[1:]:
        line_split=line.strip().split('\t')
        all_sentences.append(line_split[0])
        all_sentences.append(line_split[1])

导入包

import sys,os
from nlp_basictasks.webservices.sts_retrieve import RetrieveModel
import nlp_basictasks

定义路径加载模型

model_name_or_path='' #model_name_or_path指的就是你下载的BERT模型存放的路径 如:'chinese-roberta-wwm/'
save_index_path='' #这个变量代表你所要存储的索引的路径,如:'faiss_index.index'
save_query2id_path='' #这个变量代表你要存储的query2id的路径,如:'query2id.json'。
IR_model=RetrieveModel(save_index_path=save_index_path,
                       save_query2id_path=save_query2id_path,
                       encode_dim=768,
                      model_name_or_path=model_name_or_path)

建立索引

IR_model.createIndex(all_sentences)

问答检索

import time
start_time=time.time()
result=IR_model.retrieval("那个人正在玩电子游戏",topk=10)#检索回10个最相似的问题
end_time=time.time()
print("从%d个问题中检索一个问题需要%f ms"%(IR_model.index.ntotal,(end_time-start_time)*1000))

BERT结合Faiss实现问答检索(少于50行代码)
此外还支持动态的向索引中添加句子和删除索引中的句子,相关细节见nlp-basictasks框架做问答检索

不到50行代码即可实现问答检索,觉得好用的话还请点个star,谢谢。

上一篇:Sentence-BERT论文阅读笔记


下一篇:MindSpore21天实战营手记(二) :基于Bert进行中文新闻分类