NLP(三十三):sentence-transformers句子相似度官方示例

一、出处

https://www.sbert.net/examples/training/sts/README.html

https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/sts/training_stsbenchmark.py

二、代码

此示例从头开始为 STSbenchmark 训练 BERT(或任何其他转换器模型,如 RoBERTa、DistilBERT 等)。 它生成句子嵌入,可以使用余弦相似度进行比较以测量相似度。

用法:

python training_stsbenchmark.py 或者 python training_stsbenchmark.py pretrained_transformer_model_name

NLP(三十三):sentence-transformers句子相似度官方示例
from torch.utils.data import DataLoader
import math
from sentence_transformers import SentenceTransformer,  LoggingHandler, losses, models, util
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from sentence_transformers.readers import InputExample
import logging
from datetime import datetime
import sys
import os
import gzip
import csv

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
#### /print debug information to stdout

# Check if the dataset exists. If not, download it. The .tsv.gz file is read
# directly with gzip below, so it never has to be extracted.
sts_dataset_path = 'datasets/stsbenchmark.tsv.gz'

if not os.path.exists(sts_dataset_path):
    util.http_get('https://sbert.net/datasets/stsbenchmark.tsv.gz', sts_dataset_path)

# You can specify any huggingface/transformers pre-trained model here,
# for example bert-base-uncased, roberta-base, xlm-roberta-base.
model_name = sys.argv[1] if len(sys.argv) > 1 else 'distilbert-base-uncased'

# Training configuration
train_batch_size = 16
num_epochs = 4
# '/' in hub ids such as "org/model" is replaced so it does not create nested directories.
model_save_path = 'output/training_stsbenchmark_' + model_name.replace("/", "-") + '-' + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# Use a huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R)
# to map tokens to embeddings.
word_embedding_model = models.Transformer(model_name)

# Apply mean pooling to get one fixed-size sentence vector.
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
                               pooling_mode_mean_tokens=True,
                               pooling_mode_cls_token=False,
                               pooling_mode_max_tokens=False)

model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

# Read the dataset and split it into train / dev / test by its 'split' column.
logging.info("Read STSbenchmark train dataset")

train_samples = []
dev_samples = []
test_samples = []
with gzip.open(sts_dataset_path, 'rt', encoding='utf8') as fIn:
    reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_NONE)
    for row in reader:
        score = float(row['score']) / 5.0  # Normalize score to range 0 ... 1
        inp_example = InputExample(texts=[row['sentence1'], row['sentence2']], label=score)

        if row['split'] == 'dev':
            dev_samples.append(inp_example)
        elif row['split'] == 'test':
            test_samples.append(inp_example)
        else:
            train_samples.append(inp_example)

# Convert the training samples to a DataLoader ready for training.
train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)
train_loss = losses.CosineSimilarityLoss(model=model)

# Development set: measure the correlation between cosine score and gold labels.
logging.info("Read STSbenchmark dev dataset")
evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name='sts-dev')

# Configure the training. The dev evaluator is passed to fit() and is run
# every `evaluation_steps` training steps.
warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1)  # 10% of train data for warm-up
logging.info("Warmup-steps: {}".format(warmup_steps))

# Train the model
model.fit(train_objectives=[(train_dataloader, train_loss)],
          evaluator=evaluator,
          epochs=num_epochs,
          evaluation_steps=1000,
          warmup_steps=warmup_steps,
          output_path=model_save_path)

##############################################################################
#
# Load the stored model and evaluate its performance on STS benchmark dataset
#
##############################################################################

model = SentenceTransformer(model_save_path)
test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name='sts-test')
test_evaluator(model, output_path=model_save_path)
NLP(三十三):sentence-transformers句子相似度官方示例

 三、评估

NLP(三十三):sentence-transformers句子相似度官方示例
"""
This examples loads a pre-trained model and evaluates it on the STSbenchmark dataset
此示例加载预训练模型并在 STSbenchmark 数据集上对其进行评估 Usage: python evaluation_stsbenchmark.py OR python evaluation_stsbenchmark.py model_name """ from sentence_transformers import SentenceTransformer, util, LoggingHandler, InputExample from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator import logging import sys import torch import gzip import os import csv script_folder_path = os.path.dirname(os.path.realpath(__file__)) #Limit torch to 4 threads 将割炬限制为 4 个线程 torch.set_num_threads(4) #### Just some code to print debug information to stdout logging.basicConfig(format=‘%(asctime)s - %(message)s‘, datefmt=‘%Y-%m-%d %H:%M:%S‘, level=logging.INFO, handlers=[LoggingHandler()]) #### /print debug information to stdout model_name = sys.argv[1] if len(sys.argv) > 1 else ‘stsb-distilroberta-base-v2‘ # Load a named sentence model (based on BERT). This will download the model from our server. # Alternatively, you can also pass a filepath to SentenceTransformer()
加载命名句子模型(基于 BERT)。 这将从我们的服务器下载模型。 或者,您也可以将文件路径传递给 SentenceTransformer()
model = SentenceTransformer(model_name) sts_dataset_path = ‘data/stsbenchmark.tsv.gz‘ if not os.path.exists(sts_dataset_path): util.http_get(‘https://sbert.net/datasets/stsbenchmark.tsv.gz‘, sts_dataset_path) train_samples = [] dev_samples = [] test_samples = [] with gzip.open(sts_dataset_path, ‘rt‘, encoding=‘utf8‘) as fIn: reader = csv.DictReader(fIn, delimiter=‘\t‘, quoting=csv.QUOTE_NONE) for row in reader: score = float(row[‘score‘]) / 5.0 # Normalize score to range 0 ... 1 inp_example = InputExample(texts=[row[‘sentence1‘], row[‘sentence2‘]], label=score) if row[‘split‘] == ‘dev‘: dev_samples.append(inp_example) elif row[‘split‘] == ‘test‘: test_samples.append(inp_example) else: train_samples.append(inp_example) evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name=‘sts-dev‘) model.evaluate(evaluator) evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name=‘sts-test‘) model.evaluate(evaluator)
NLP(三十三):sentence-transformers句子相似度官方示例

 四、继续训练

NLP(三十三):sentence-transformers句子相似度官方示例
"""
This example loads the pre-trained SentenceTransformer model ‘nli-distilroberta-base-v2‘ from the server.
It then fine-tunes this model for some epochs on the STS benchmark dataset.
Note: In this example, you must specify a SentenceTransformer model.
If you want to fine-tune a huggingface/transformers model like bert-base-uncased, see training_nli.py and training_stsbenchmark.py

此示例从服务器加载预训练的 SentenceTransformer 模型“nli-distilroberta-base-v2”。
然后,在 STS 基准数据集上对该模型进行若干个 epoch 的微调。
注意:在此示例中,您必须指定 SentenceTransformer 模型。
如果你想微调像 bert-base-uncased 这样的 huggingface/transformers 模型,请参阅 training_nli.py 和 training_stsbenchmark.py

"""
from torch.utils.data import DataLoader
import math
from sentence_transformers import SentenceTransformer, LoggingHandler, losses, util, InputExample
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
import logging
from datetime import datetime
import os
import gzip
import csv

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
#### /print debug information to stdout

# Check if the dataset exists. If not, download it. The .tsv.gz file is read
# directly with gzip below, so it never has to be extracted.
sts_dataset_path = 'datasets/stsbenchmark.tsv.gz'

if not os.path.exists(sts_dataset_path):
    util.http_get('https://sbert.net/datasets/stsbenchmark.tsv.gz', sts_dataset_path)


# Training configuration
model_name = 'nli-distilroberta-base-v2'
train_batch_size = 16
num_epochs = 4
# '/' in hub ids such as "org/model" is replaced so it does not create nested directories.
model_save_path = 'output/training_stsbenchmark_continue_training-' + model_name.replace("/", "-") + '-' + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")


# Load the pre-trained SentenceTransformer model that will be fine-tuned further.
model = SentenceTransformer(model_name)

# Read the dataset and split it into train / dev / test by its 'split' column.
logging.info("Read STSbenchmark train dataset")

train_samples = []
dev_samples = []
test_samples = []
with gzip.open(sts_dataset_path, 'rt', encoding='utf8') as fIn:
    reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_NONE)
    for row in reader:
        score = float(row['score']) / 5.0  # Normalize score to range 0 ... 1
        inp_example = InputExample(texts=[row['sentence1'], row['sentence2']], label=score)

        if row['split'] == 'dev':
            dev_samples.append(inp_example)
        elif row['split'] == 'test':
            test_samples.append(inp_example)
        else:
            train_samples.append(inp_example)


# Convert the training samples to a DataLoader ready for training.
train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)
train_loss = losses.CosineSimilarityLoss(model=model)


# Development set: measure the correlation between cosine score and gold labels.
logging.info("Read STSbenchmark dev dataset")
evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name='sts-dev')


# Configure the training. The dev evaluator is passed to fit() and is run
# every `evaluation_steps` training steps.
warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1)  # 10% of train data for warm-up
logging.info("Warmup-steps: {}".format(warmup_steps))


# Train the model
model.fit(train_objectives=[(train_dataloader, train_loss)],
          evaluator=evaluator,
          epochs=num_epochs,
          evaluation_steps=1000,
          warmup_steps=warmup_steps,
          output_path=model_save_path)


##############################################################################
#
# Load the stored model and evaluate its performance on STS benchmark dataset
#
##############################################################################

model = SentenceTransformer(model_save_path)
test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name='sts-test')
test_evaluator(model, output_path=model_save_path)
NLP(三十三):sentence-transformers句子相似度官方示例

 五、Faiss语义检索

NLP(三十三):sentence-transformers句子相似度官方示例
"""
This example uses Approximate Nearest Neighbor Search (ANN) with FAISS (https://github.com/facebookresearch/faiss).
Searching a large corpus with Millions of embeddings can be time-consuming. To speed this up,
ANN can index the existent vectors. For a new query vector, this index can be used to find the nearest neighbors.
This nearest neighbor search is not perfect, i.e., it might not perfectly find all top-k nearest neighbors.
In this example, we use FAISS with an inverted-file flat index (IndexIVFFlat). It learns to partition the corpus embeddings
into different clusters (the number is defined by n_clusters). At search time, the clusters closest to the query are found (how many is set by index.nprobe)
and only the vectors in those clusters are searched for nearest neighbors.
This script will compare the result from ANN with exact nearest neighbor search and output a Recall@k value
as well as the missing results in the top-k hits list.
See the FAISS repository, how to install FAISS.
As dataset, we use the Quora Duplicate Questions dataset, which contains about 500k questions (only 100k are used):
https://www.quora.com/q/quoradata/First-Quora-Dataset-Release-Question-Pairs.
As embeddings model, we use the SBERT model ‘quora-distilbert-multilingual‘,
which is aligned across 100 languages, i.e., you can type in a question in various languages and it will
return the closest questions in the corpus (questions in the corpus are mainly in English).

此示例使用带有 FAISS (https://github.com/facebookresearch/faiss) 的近似最近邻搜索 (ANN)。
搜索具有数百万个嵌入的大型语料库可能非常耗时。为了加快速度,
ANN 可以索引现有的向量。对于新的查询向量,该索引可用于查找最近的邻居。
这种最近邻搜索并不完美,即它可能无法完美地找到所有前 k 个最近邻。
在此示例中,我们使用 FAISS 的倒排平坦索引 (IndexIVFFlat)。它学习将语料库嵌入划分
到不同的簇中(簇的数量由 n_clusters 定义)。搜索时,先找到与查询最接近的若干个簇(数量由 index.nprobe 决定),
然后只在这些簇内的向量中搜索最近邻。
此脚本将 ANN 的结果与精确的最近邻搜索进行比较并输出 Recall@k 值
以及 top-k 命中列表中缺失的结果。
请参阅 FAISS 存储库,了解如何安装 FAISS。
作为数据集,我们使用 Quora Duplicate Questions 数据集,其中包含大约 500k 个问题(仅使用了 100k):
https://www.quora.com/q/quoradata/First-Quora-Dataset-Release-Question-Pairs。
作为嵌入模型,我们使用 SBERT 模型“quora-distilbert-multilingual”,
该模型在 100 种语言之间进行了对齐。也就是说,你可以用各种语言输入一个问题,它会
返回语料库中最接近的问题(语料库中的问题主要是英文)。

"""
from sentence_transformers import SentenceTransformer, util
import os
import csv
import pickle
import time
import faiss
import numpy as np


model_name = 'quora-distilbert-multilingual'
model = SentenceTransformer(model_name)

url = "http://qim.fs.quoracdn.net/quora_duplicate_questions.tsv"
dataset_path = "quora_duplicate_questions.tsv"
max_corpus_size = 100000

embedding_cache_path = 'quora-embeddings-{}-size-{}.pkl'.format(model_name.replace('/', '_'), max_corpus_size)


# Size of embeddings. Read it from the model instead of hard-coding 768, so
# switching model_name to a model with a different dimension still works.
embedding_size = model.get_sentence_embedding_dimension()
top_k_hits = 10         # Output k hits

# Defining our FAISS index
# Number of clusters used for faiss. Select a value 4*sqrt(N) to 16*sqrt(N) - https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index
n_clusters = 1024

# We use Inner Product (dot-product) as Index. We will normalize our vectors
# to unit length, so the Inner Product is equal to the cosine similarity.
quantizer = faiss.IndexFlatIP(embedding_size)
index = faiss.IndexIVFFlat(quantizer, embedding_size, n_clusters, faiss.METRIC_INNER_PRODUCT)

# Number of clusters to explore at search time. We will search for nearest neighbors in 3 clusters.
index.nprobe = 3

# Check if embedding cache path exists
if not os.path.exists(embedding_cache_path):
    # Download the dataset if needed
    if not os.path.exists(dataset_path):
        print("Download dataset")
        util.http_get(url, dataset_path)

    # Get all unique sentences from the file
    corpus_sentences = set()
    with open(dataset_path, encoding='utf8') as fIn:
        reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_MINIMAL)
        for row in reader:
            corpus_sentences.add(row['question1'])
            if len(corpus_sentences) >= max_corpus_size:
                break

            corpus_sentences.add(row['question2'])
            if len(corpus_sentences) >= max_corpus_size:
                break

    corpus_sentences = list(corpus_sentences)
    print("Encode the corpus. This might take a while")
    corpus_embeddings = model.encode(corpus_sentences, show_progress_bar=True, convert_to_numpy=True)

    print("Store file on disc")
    with open(embedding_cache_path, "wb") as fOut:
        pickle.dump({'sentences': corpus_sentences, 'embeddings': corpus_embeddings}, fOut)
else:
    print("Load pre-computed embeddings from disc")
    # NOTE: pickle.load can run arbitrary code; only load cache files this script wrote itself.
    with open(embedding_cache_path, "rb") as fIn:
        cache_data = pickle.load(fIn)
        corpus_sentences = cache_data['sentences']
        corpus_embeddings = cache_data['embeddings']

### Create the FAISS index
print("Start creating FAISS index")
# First, we need to normalize vectors to unit length
corpus_embeddings = corpus_embeddings / np.linalg.norm(corpus_embeddings, axis=1)[:, None]

# Then we train the index to find a suitable clustering
index.train(corpus_embeddings)

# Finally we add all embeddings to the index
index.add(corpus_embeddings)


######### Search in the index ###########

print("Corpus loaded with {} sentences / embeddings".format(len(corpus_sentences)))

while True:
    inp_question = input("Please enter a question: ")

    start_time = time.time()
    question_embedding = model.encode(inp_question)

    # FAISS works with inner product (dot product). When we normalize vectors
    # to unit length, inner product is equal to cosine similarity.
    question_embedding = question_embedding / np.linalg.norm(question_embedding)
    question_embedding = np.expand_dims(question_embedding, axis=0)

    # Search in FAISS. It returns a matrix with distances and corpus ids.
    distances, corpus_ids = index.search(question_embedding, top_k_hits)

    # Extract corpus ids and scores for the first (only) query. When the
    # probed clusters hold fewer than k vectors, FAISS pads the result with
    # id -1. Drop those ids; otherwise they would silently index
    # corpus_sentences[-1] and inflate the reported recall.
    hits = [{'corpus_id': corpus_id, 'score': score}
            for corpus_id, score in zip(corpus_ids[0], distances[0])
            if corpus_id >= 0]
    hits = sorted(hits, key=lambda x: x['score'], reverse=True)
    end_time = time.time()

    print("Input question:", inp_question)
    print("Results (after {:.3f} seconds):".format(end_time - start_time))
    for hit in hits[0:top_k_hits]:
        print("\t{:.3f}\t{}".format(hit['score'], corpus_sentences[hit['corpus_id']]))

    # Approximate Nearest Neighbor (ANN) is not exact, it might miss entries with high cosine similarity.
    # Here, we compute the recall of ANN compared to the exact results.
    correct_hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k_hits)[0]
    correct_hits_ids = set([hit['corpus_id'] for hit in correct_hits])

    ann_corpus_ids = set([hit['corpus_id'] for hit in hits])
    if len(ann_corpus_ids) != len(correct_hits_ids):
        print("Approximate Nearest Neighbor returned a different number of results than expected")

    recall = len(ann_corpus_ids.intersection(correct_hits_ids)) / len(correct_hits_ids)
    print("\nApproximate Nearest Neighbor Recall@{}: {:.2f}".format(top_k_hits, recall * 100))

    if recall < 1:
        print("Missing results:")
        for hit in correct_hits[0:top_k_hits]:
            if hit['corpus_id'] not in ann_corpus_ids:
                print("\t{:.3f}\t{}".format(hit['score'], corpus_sentences[hit['corpus_id']]))
    print("\n\n========\n")
NLP(三十三):sentence-transformers句子相似度官方示例

 

 六、实战

NLP(三十三):sentence-transformers句子相似度官方示例
from torch.utils.data import DataLoader
import math
from sentence_transformers import SentenceTransformer,  LoggingHandler, losses, models, util
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from sentence_transformers.readers import InputExample
import logging
from datetime import datetime
import os
from root_path import root
import pandas as pd

class MySentenceBert():
    """Fine-tune a ``bert-base-chinese`` + mean-pooling sentence encoder on sentence pairs.

    Data is read from tab-separated ``train.csv`` / ``val.csv`` / ``test.csv``
    files under ``<root>/data/sim_data``. Each file must have the columns ``s1``
    and ``s2`` (the sentence pair) and ``y`` (the similarity label). The trained
    model is written to a timestamped directory under ``<root>/chkpt``.
    """

    # Executed once, when the class body runs (i.e. when this module is imported).
    logging.basicConfig(format='%(asctime)s - %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S',
                        level=logging.INFO,
                        handlers=[LoggingHandler()])

    def __init__(self):
        """Set hyper-parameters, load the three data splits and choose the output directory."""
        self.train_batch_size = 16
        self.num_epochs = 4
        data_path = os.path.join(root, "data", "sim_data")
        self.train_data = pd.read_csv(os.path.join(data_path, "train.csv"), sep="\t")
        self.val_data = pd.read_csv(os.path.join(data_path, "val.csv"), sep="\t")
        self.test_data = pd.read_csv(os.path.join(data_path, "test.csv"), sep="\t")
        self.model_save_path = os.path.join(root, "chkpt", "sentence_bert_model" +
                                            datetime.now().strftime("_%Y_%m_%d_%H_%M"))

    @staticmethod
    def _to_examples(df):
        """Convert the ``s1``/``s2``/``y`` columns of *df* into a list of ``InputExample``."""
        # Labels are used as-is (no /5.0 rescaling as in the STSbenchmark example).
        # Presumably ``y`` is already on the scale CosineSimilarityLoss expects — TODO confirm.
        return [InputExample(texts=[s1, s2], label=float(y))
                for s1, s2, y in zip(df["s1"], df["s2"], df["y"])]

    def data_generator(self):
        """Return ``(train_examples, dev_examples, test_examples)`` built from the loaded splits."""
        logging.info("generator dataset")
        return (self._to_examples(self.train_data),
                self._to_examples(self.val_data),
                self._to_examples(self.test_data))

    def train(self, train_datas, dev_datas, model):
        """Fine-tune *model* with CosineSimilarityLoss and evaluate on *dev_datas* during training.

        The trained model is written to ``self.model_save_path``.
        """
        train_dataloader = DataLoader(train_datas, shuffle=True, batch_size=self.train_batch_size)
        train_loss = losses.CosineSimilarityLoss(model=model)
        evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_datas, name='sts-dev')
        # Warm up the learning rate over the first 10% of all training steps.
        warmup_steps = math.ceil(len(train_dataloader) * self.num_epochs * 0.1)
        logging.info("Warmup-steps: {}".format(warmup_steps))
        model.fit(train_objectives=[(train_dataloader, train_loss)],
                  evaluator=evaluator,
                  epochs=self.num_epochs,
                  evaluation_steps=1000,
                  warmup_steps=warmup_steps,
                  output_path=self.model_save_path)

    def test(self, test_samples):
        """Reload the model saved by :meth:`train` and evaluate it on *test_samples*."""
        model = SentenceTransformer(self.model_save_path)
        test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name='sts-test')
        test_evaluator(model, output_path=self.model_save_path)

    def main(self):
        """Build the encoder from the local ``bert-base-chinese`` checkpoint, then train and test it."""
        train_datas, dev_datas, test_datas = self.data_generator()

        model_name = os.path.join(root, "chkpt", "bert-base-chinese")
        word_embedding_model = models.Transformer(model_name)
        # Mean pooling turns the token embeddings into one fixed-size sentence vector.
        pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
                                       pooling_mode_mean_tokens=True,
                                       pooling_mode_cls_token=False,
                                       pooling_mode_max_tokens=False)
        model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
        self.train(train_datas, dev_datas, model)
        self.test(test_datas)

if __name__ == '__main__':
    # Entry point: train and evaluate the sentence encoder end to end.
    MySentenceBert().main()
NLP(三十三):sentence-transformers句子相似度官方示例

 

NLP(三十三):sentence-transformers句子相似度官方示例
from sentence_transformers import SentenceTransformer, util
import os
import csv
import pickle
import time
from root_path import root
import json

class SemanticSearch():
    """Find the corpus sentence most similar to a query using a fine-tuned SentenceTransformer.

    The corpus is read from ``<root>/data/bert_data/index.txt``, one
    ``sentence<TAB>code`` record per line. Each ``code`` is mapped to a label
    through ``<root>/config/code_to_label.json``. Corpus embeddings are computed
    once and cached in a pickle file, so later instances skip the encoding step.
    """

    def __init__(self, model_name=None, embedding_cache_path='semantic_search_embedding.pkl'):
        """Load the model, the code->label mapping and the (cached) corpus embeddings.

        Args:
            model_name: Model name or local directory passed to ``SentenceTransformer``.
                Defaults to the fine-tuned checkpoint under ``<root>/chkpt``.
            embedding_cache_path: Pickle file holding the corpus sentences, codes
                and embeddings. The cache is not keyed by model or corpus, so
                delete it (or pass another path) whenever either changes.
                Otherwise stale embeddings are silently reused.
        """
        if model_name is None:
            model_name = os.path.join(root, "chkpt", "sentence_bert_model_2021_08_05_18_16")
        self.model = SentenceTransformer(model_name)
        dataset_path = os.path.join(root, "data", "bert_data", "index.txt")
        with open(os.path.join(root, "config", "code_to_label.json"), "r", encoding="utf8") as f:
            self.d = json.load(f)
        self.sentences = list()
        self.code = list()
        if not os.path.exists(embedding_cache_path):
            self._read_corpus(dataset_path)
            print("Encode the corpus. This might take a while")
            self.embeddings = self.model.encode(self.sentences, show_progress_bar=True, convert_to_tensor=True)
            print("Store file on disc")
            with open(embedding_cache_path, "wb") as fOut:
                pickle.dump({'sentences': self.sentences, 'embeddings': self.embeddings, "code": self.code}, fOut)
        else:
            print("Load pre-computed embeddings from disc")
            # NOTE: pickle.load can run arbitrary code; only load cache files written by this class.
            with open(embedding_cache_path, "rb") as fIn:
                cache_data = pickle.load(fIn)
            self.sentences = cache_data['sentences']
            self.embeddings = cache_data['embeddings']
            self.code = cache_data["code"]

    def _read_corpus(self, dataset_path):
        """Append the ``sentence<TAB>code`` records of *dataset_path* to ``self.sentences`` / ``self.code``.

        Blank lines are skipped. A non-blank line without a tab raises ValueError
        naming the offending line.
        """
        with open(dataset_path, encoding='utf8') as fIn:
            for line_no, read_line in enumerate(fIn, start=1):
                fields = read_line.rstrip("\n").split("\t")
                if fields == [""]:
                    # Tolerate empty lines instead of failing with IndexError.
                    continue
                if len(fields) < 2:
                    raise ValueError("{}:{}: expected 'sentence<TAB>code', got {!r}".format(
                        dataset_path, line_no, read_line))
                self.sentences.append(fields[0])
                self.code.append(fields[1])

    def query(self, query):
        """Return ``(label, score, text)`` for the corpus sentence closest to *query*.

        ``score`` is the similarity reported by ``util.semantic_search`` and
        ``text`` is the matched corpus sentence.
        """
        question_embedding = self.model.encode(query, convert_to_tensor=True)
        # Only the best match is used, so request a single hit.
        hits = util.semantic_search(question_embedding, self.embeddings, top_k=1)
        hit = hits[0][0]  # best hit for the first (and only) query
        score = hit['score']
        text = self.sentences[hit['corpus_id']]
        kh_code = self.code[hit['corpus_id']]
        # Presumably each JSON value is a sequence whose element 1 is the readable
        # label — TODO confirm against code_to_label.json.
        label = self.d[kh_code][1]
        return label, score, text

    def main(self):
        """Smoke test: run one query and print its result."""
        print(self.query("你好"))

if __name__ == '__main__':
    # Entry point: build the search index (or load the cache) and run a demo query.
    SemanticSearch().main()
NLP(三十三):sentence-transformers句子相似度官方示例

 

NLP(三十三):sentence-transformers句子相似度官方示例

上一篇:Linux 的select、poll、epoll


下一篇:文件管理基础命令之一