faiss是Facebook开源的相似性搜索库,为稠密向量提供高效相似度搜索和聚类,支持十亿级别向量的搜索,是目前最为成熟的近似近邻搜索库
faiss不直接提供余弦距离计算,而是提供了欧式距离和点积,利用余弦距离公式,经过L2正则后的向量点积结果即为余弦距离,所以利用faiss计算余弦距离需要先对输入进行L2正则化
-
安装
参照官方开源安装https://github.com/facebookresearch/faiss/blob/main/INSTALL.md
# CPU-only version $ conda install -c pytorch faiss-cpu $ pip install faiss-cpu # GPU(+CPU) version $ conda install -c pytorch faiss-gpu $ pip install faiss-cpu
-
常规计算余弦距离方式
常规一般使用sklearn包的cosine_similarity计算余弦距离,因为该包自动对向量进行L2正则,所以不要求输入必须为正则结果,代码如下:
## 计算余弦距离 from sklearn.metrics.pairwise import cosine_similarity from sklearn import preprocessing def get_cos_result(embeding_library, persons, embeding_search): simi = cosine_similarity(embeding_search, embeding_library) max_argmin = np.argmax(simi,axis=1) search_speaker = [[persons[id],simi[i][id]] for i, id in enumerate(max_argmin)] return search_speaker ## 对输入进行正则化,可以不用正则 def l2_normal(embeding): return preprocessing.normalize(embeding)
-
faiss的精确搜索
faiss并不提供计算与余弦距离,只提供了点积计算和欧式距离,所以在计算余弦距离时,需要对输入进行L2正则,代码如下:
import faiss from faiss import normalize_L2 def faiss_precise_search(embeding_library, persons, embeding_search,topk=1): ## 这里也可以使用上文的sklearn的包进行正则 normalize_L2(embeding_search) normalize_L2(embeding_library) # faiss.IndexFlatIP是内积 ;faiss.indexFlatL2是欧式距离 quantizer = faiss.IndexFlatIP(embeding_library.shape[1]) index = quantizer ## 要保证输入为np.float32格式 index.add(embeding_library.astype(np.float32)) library = {'persons': persons, 'index': index} st = time.time() distance,idx = library['index'].search(embeding_search,topk) print('precise search:',time.time()-st) combined_results = [] for p in range(len(distance)): results = [[library["persons"][i], s] for i, s in zip(idx[p], distance[p]) if s >= 0][0] combined_results.append(results) return combined_results
-
faiss快速搜索
faiss提供了多种快速搜索的方式,这里介绍常用的一种加速搜索的方式:倒排索引,这种方式与ES快速搜索的方式类似,需要先使用k-means建立聚类中心,通过查询最近的聚类中心,然后比较聚类中所有向量得到相似向量,这里需要两个超参数,一个是聚类中心num_cells,一个是查找聚类中心的个数num_cells_in_search,具体代码如下
def faiss_fast_search(embeding_library, persons, embeding_search,topk=1): normalize_L2(embeding_search) normalize_L2(embeding_library) d = embeding_library.shape[1] num_cells = 50 num_cells_in_search = 5 # 声明量化器 quantizer = faiss.IndexFlatIP(embeding_library.shape[1]) # faiss.METRIC_INNER_PRODUCT计算内积 faiss.METRIC_L2j计算欧式距离 index = faiss.IndexIVFFlat(quantizer, d,min(num_cells, len(persons)),faiss.METRIC_INNER_PRODUCT) assert not index.is_trained index.train(embeding_library.astype(np.float32)) index.add(embeding_library.astype(np.float32)) index.nprobe = min(num_cells_in_search,len(persons)) library = {'persons': persons, 'index': index} st = time.time() distance, idx = library['index'].search(embeding_search, topk) print('fast search:',time.time()-st) combined_results = [] for p in range(len(distance)): results = [[library["persons"][i], s] for i, s in zip(idx[p], distance[p]) if s >= 0][0] combined_results.append(results) return combined_results
-
整体代码
# -*- coding: utf-8 -*- import faiss from faiss import normalize_L2 from sklearn.metrics.pairwise import cosine_similarity from sklearn import preprocessing import numpy as np import time def l2_normal(embeding): return preprocessing.normalize(embeding) def get_cos_result(embeding_search, persons, embeding_library): simi = cosine_similarity(embeding_search, embeding_library) max_argmin = np.argmax(simi,axis=1) search_speaker = [[persons[id],simi[i][id]] for i, id in enumerate(max_argmin)] return search_speaker def faiss_precise_search(embeding_library, persons, embeding_search): normalize_L2(embeding_search) normalize_L2(embeding_library) # faiss.IndexFlatIP是内积 ;faiss.indexFlatL2是欧式距离 quantizer = faiss.IndexFlatIP(embeding_library.shape[1]) index = quantizer index.add(embeding_library.astype(np.float32)) library = {'persons': persons, 'index': index} st = time.time() distance,idx = library['index'].search(embeding_search,1) print('precise search:',time.time()-st) combined_results = [] for p in range(len(distance)): results = [[library["persons"][i], s] for i, s in zip(idx[p], distance[p]) if s >= 0][0] combined_results.append(results) return combined_results def faiss_fast_search(embeding_library, persons, embeding_search,topk=1): normalize_L2(embeding_search) normalize_L2(embeding_library) num_cells = 500 num_cells_in_search = 10 quantizer = faiss.IndexFlatIP(embeding_library.shape[1]) index = faiss.IndexIVFFlat(quantizer, embeding_library.shape[1],min(num_cells, len(persons)),faiss.METRIC_INNER_PRODUCT) #faiss.METRIC_INNER_PRODUCT计算内积 faiss.METRIC_L2j计算欧式距离 assert not index.is_trained index.train(embeding_library.astype(np.float32)) index.add(embeding_library.astype(np.float32)) index.nprobe = min(num_cells_in_search,len(persons)) library = {'persons': persons, 'index': index} st = time.time() distance, idx = library['index'].search(embeding_search, topk) print('fast search:',time.time()-st) combined_results = [] for p in range(len(distance)): results = [[library["persons"][i], s] for i, s in zip(idx[p], distance[p]) if s >= 0][0] combined_results.append(results) return combined_results if __name__ == '__main__': d = 512 n_library = 100000 n_search = 1 embeding_library = np.random.random((n_library, d)).astype(np.float32) persons = ['Speak' + "%0d" % (i + 1) for i in range(n_library)] embeding_search = np.random.random((n_search, d)).astype(np.float32) print(faiss_fast_search(embeding_library, persons, embeding_search)) print(faiss_precise_search(embeding_library, persons, embeding_search)) st = time.time() print(get_cos_result(embeding_search, persons, embeding_library)) en1 = time.time() print(en1-st)