本文主要介绍一个框架nlp-basictasks
nlp-basictasks是利用PyTorch深度学习框架所构建一个简单的库,旨在快速搭建模型完成一些基础的NLP任务,如分类、匹配、序列标注、语义相似度计算等。
下面利用该框架实现问答检索
所谓检索,就是将本地所有待检索的每一个句子编码成一个向量,将这些向量存储起来,称之为索引index。然后针对用户的问题query,将其编码为一个对应的向量query_vector,计算query_vector和索引中每一个vector的相似度,得到相似度最高的几个vector,最后返回这些vector对应的问题。
此外还需要一个变量query2id,
query2id指的是每一个问题和对应的id之间的映射,这个id就是索引中这个问题对应的向量的id,因为当我们计算完相似度最高的几个向量后,还要根据向量在索引中的id从query2id中找出对应的问题
数据集介绍
实验所用数据集是常用的中文自然语言推理数据集lcqmc,来源http://icrc.hitsz.edu.cn/Article/show/171.html
all_sentences=[]
data_folder='lcqmc/lcqmc_test.tsv'#我们只取少量的test.tsv实验
with open(data_folder) as f:
lines=f.readlines()
for line in lines[1:]:
line_split=line.strip().split('\t')
all_sentences.append(line_split[0])
all_sentences.append(line_split[1])
导入包
import sys,os
from nlp_basictasks.webservices.sts_retrieve import RetrieveModel
import nlp_basictasks
定义路径加载模型
model_name_or_path='' #model_name_or_path指的就是你下载的BERT模型存放的路径 如:'chinese-roberta-wwm/'
save_index_path='' #这个变量代表你所要存储的索引的路径,如:'faiss_index.index'
save_query2id_path='' #这个变量代表你要存储的query2id的路径,如:'query2id.json'。
IR_model=RetrieveModel(save_index_path=save_index_path,
save_query2id_path=save_query2id_path,
encode_dim=768,
model_name_or_path=model_name_or_path)
建立索引
IR_model.createIndex(all_sentences)
问答检索
import time
start_time=time.time()
result=IR_model.retrieval("那个人正在玩电子游戏",topk=10)#检索回10个最相似的问题
end_time=time.time()
print("从%d个问题中检索一个问题需要%f ms"%(IR_model.index.ntotal,(end_time-start_time)*1000))
此外还支持动态的向索引中添加句子和删除索引中的句子,相关细节见nlp-basictasks框架做问答检索
不到50行代码即可实现问答检索,觉得好用的话还请点个star,谢谢。