1、Maven 依赖
<!-- Milvus Java SDK 2.3.3. The exclusions drop the SDK's transitive slf4j-api and
     log4j-slf4j-impl artifacts; presumably this avoids clashing with the SLF4J API/binding the
     application already provides (e.g. via Spring Boot) - confirm against the project build. -->
<dependency>
<groupId>io.milvus</groupId>
<artifactId>milvus-sdk-java</artifactId>
<version>2.3.3</version>
<exclusions>
<exclusion>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
</exclusion>
<exclusion>
<groupId>org.apache.logging.log4j</groupId>
<artifactId>log4j-slf4j-impl</artifactId>
</exclusion>
</exclusions>
</dependency>
2、MivusService：封装 Milvus 的基本操作
@Service
@Slf4j
public class MivusService {
@Autowired
MilvusServiceClient milvusClient;
private String clientId;
/**
* 同步搜索milvus
* @param collectionName 表名
* @param vectors 查询向量
* @param topK 最相似的向量个数
* @return
*/
public List<Long> search(String collectionName, List<List<Float>> vectors, Integer topK) {
Assert.notNull(collectionName, "collectionName is null");
Assert.notNull(vectors, "vectors is null");
Assert.notEmpty(vectors, "vectors is empty");
Assert.notNull(topK, "topK is null");
int nprobeVectorSize = vectors.get(0).size();
String paramsInJson = "{\"nprobe\": " + nprobeVectorSize + "}";
SearchParam searchParam =
SearchParam.newBuilder().withCollectionName(collectionName)
.withParams(paramsInJson)
.withMetricType(MetricType.L2)
.withVectors(vectors)
.withVectorFieldName("embeddings")
.withTopK(topK)
.build();
R<SearchResults> searchResultsR = milvusClient.search(searchParam);
SearchResults searchResultsRData = searchResultsR.getData();
List<Long> topksList = searchResultsRData.getResults().getIds().getIntId().getDataList();
return topksList;
}
/**
* 同步搜索milvus
* @param collectionName 表名
* @param vectors 查询向量
* @param topK 最相似的向量个数
* @return
*/
public List<Long> search1(String collectionName, List<List<Float>> vectors, Integer topK) {
Assert.notNull(collectionName, "collectionName is null");
Assert.notNull(vectors, "vectors is null");
Assert.notEmpty(vectors, "vectors is empty");
Assert.notNull(topK, "topK is null");
int nprobeVectorSize = vectors.get(0).size();
String paramsInJson = "{\"nprobe\": " + nprobeVectorSize + "}";
SearchParam searchParam =
SearchParam.newBuilder().withCollectionName(collectionName)
.withParams(paramsInJson)
.withMetricType(MetricType.IP)
.withVectors(vectors)
.withVectorFieldName("embedding")
.withTopK(topK)
.build();
R<SearchResults> searchResultsR = milvusClient.search(searchParam);
SearchResults searchResultsRData = searchResultsR.getData();
List<Long> topksList = searchResultsRData.getResults().getIds().getIntId().getDataList();
return topksList;
}
/**
* 同步搜索milvus,增加过滤条件搜索
*
* @param collectionName 表名
* @param vectors 查询向量
* @param topK 最相似的向量个数
* @param exp 过滤条件:status=1
* @return
*/
public List<Long> search2(String collectionName, List<List<Float>> vectors, Integer topK, String exp) {
Assert.notNull(collectionName, "collectionName is null");
Assert.notNull(vectors, "vectors is null");
Assert.notEmpty(vectors, "vectors is empty");
Assert.notNull(topK, "topK is null");
Assert.notNull(exp, "exp is null");
int nprobeVectorSize = vectors.get(0).size();
String paramsInJson = "{\"nprobe\": " + nprobeVectorSize + "}";
SearchParam searchParam =
SearchParam.newBuilder().withCollectionName(collectionName)
.withParams(paramsInJson)
.withMetricType(MetricType.IP)
.withVectors(vectors)
.withExpr(exp)
.withVectorFieldName("embedding")
.withTopK(topK)
.build();
R<SearchResults> searchResultsR = milvusClient.search(searchParam);
SearchResults searchResultsRData = searchResultsR.getData();
List<Long> topksList = searchResultsRData.getResults().getIds().getIntId().getDataList();
return topksList;
}
/**
* 异步搜索milvus
*
* @param collectionName 表名
* @param vectors 查询向量
* @param partitionList 最相似的向量个数
* @param topK
* @return
*/
public List<Long> searchAsync(String collectionName, List<List<Float>> vectors,
List<String> partitionList, Integer topK) throws ExecutionException, InterruptedException {
Assert.notNull(collectionName, "collectionName is null");
Assert.notNull(vectors, "vectors is null");
Assert.notEmpty(vectors, "vectors is empty");
Assert.notNull(partitionList, "partitionList is null");
Assert.notEmpty(partitionList, "partitionList is empty");
Assert.notNull(topK, "topK is null");
int nprobeVectorSize = vectors.get(0).size();
String paramsInJson = "{\"nprobe\": " + nprobeVectorSize + "}";
SearchParam searchParam =
SearchParam.newBuilder().withCollectionName(collectionName)
.withParams(paramsInJson)
.withVectors(vectors)
.withTopK(topK)
.withPartitionNames(partitionList)
.build();
ListenableFuture<R<SearchResults>> listenableFuture = milvusClient.searchAsync(searchParam);
List<Long> resultIdsList = listenableFuture.get().getData().getResults().getTopksList();
return resultIdsList;
}
/**
* 获取分区集合
* @param collectionName 表名
* @return
*/
public List<String> getPartitionsList(String collectionName) {
Assert.notNull(collectionName, "collectionName is null");
ShowPartitionsParam searchParam = ShowPartitionsParam.newBuilder().withCollectionName(collectionName).build();
List<ByteString> byteStrings = milvusClient.showPartitions(searchParam).getData().getPartitionNamesList().asByteStringList();
List<String> partitionList = Lists.newLinkedList();
byteStrings.forEach(s -> {
partitionList.add(s.toStringUtf8());
});
return partitionList;
}
public void loadCollection(String collectionName) {
LoadCollectionParam loadCollectionParam = LoadCollectionParam.newBuilder()
.withCollectionName(collectionName)
.build();
R<RpcStatus> response = milvusClient.loadCollection(loadCollectionParam);
log.info("loadCollection {} is {}", collectionName, response.getData().getMsg());
}
public void releaseCollection(String collectionName) {
ReleaseCollectionParam param = ReleaseCollectionParam.newBuilder()
.withCollectionName(collectionName)
.build();
R<RpcStatus> response = milvusClient.releaseCollection(param);
log.info("releaseCollection {} is {}", collectionName, response.getData().getMsg());
}
public void loadPartitions(String collectionName, List<String> partitionsName) {
LoadPartitionsParam build = LoadPartitionsParam.newBuilder()
.withCollectionName(collectionName)
.withPartitionNames(partitionsName)
.build();
R<RpcStatus> rpcStatusR = milvusClient.loadPartitions(build);
log.info("loadPartitions {} is {}", partitionsName, rpcStatusR.getData().getMsg());
}
public void releasePartitions(String collectionName, List<String> partitionsName) {
ReleasePartitionsParam build = ReleasePartitionsParam.newBuilder()
.withCollectionName(collectionName)
.withPartitionNames(partitionsName)
.build();
R<RpcStatus> rpcStatusR = milvusClient.releasePartitions(build);
log.info("releasePartition {} is {}", collectionName, rpcStatusR.getData().getMsg());
}
public boolean isExitCollection(String collectionName) {
HasCollectionParam hasCollectionParam = HasCollectionParam.newBuilder()
.withCollectionName(collectionName)
.build();
R<Boolean> response = milvusClient.hasCollection(hasCollectionParam);
Boolean isExists = response.getData();
log.info("collection {} is exists: {}", collectionName, isExists);
return isExists;
}
public Boolean creatCollection(String collectionName) {
// 主键字段
FieldType fieldType1 = FieldType.newBuilder()
.withName(Content.Field.ID)
.withDescription("primary key")
.withDataType(DataType.Int64)
.withPrimaryKey(true)
.withAutoID(true)
.build();
// 文本字段
FieldType fieldType2 = FieldType.newBuilder()
.withName(Content.Field.CONTENT)
.withDataType(DataType.VarChar)
.withMaxLength(Content.MAX_LENGTH)
.build();
// 向量字段
FieldType fieldType3 = FieldType.newBuilder()
.withName(Content.Field.CONTENT_VECTOR)
.withDataType(DataType.FloatVector)
.withDimension(Content.FEATURE_DIM)
.build();
// 创建collection
CreateCollectionParam createCollectionReq = CreateCollectionParam.newBuilder()
.withCollectionName(collectionName)
.withDescription("Schema of Content")
.withShardsNum(Content.SHARDS_NUM)
.addFieldType(fieldType1)
.addFieldType(fieldType2)
.addFieldType(fieldType3)
.build();
R<RpcStatus> response = milvusClient.createCollection(createCollectionReq);
log.info("collection: {} is created ? status: = {}", collectionName, response.getData().getMsg());
return response.getData().getMsg().equals("Success");
}
public Boolean dropCollection(String collectionName) {
DropCollectionParam book = DropCollectionParam.newBuilder()
.withCollectionName(collectionName)
.build();
R<RpcStatus> response = milvusClient.dropCollection(book);
return response.getData().getMsg().equals("Success");
}
public void createPartition(String collectionName, String partitionName) {
CreatePartitionParam param = CreatePartitionParam.newBuilder()
.withCollectionName(collectionName)
.withPartitionName(partitionName)
.build();
R<RpcStatus> partition = milvusClient.createPartition(param);
String msg = partition.getData().getMsg();
log.info("create partition: {} in collection: {} is: {}", partition, collectionName, msg);
}
public Boolean createIndex(String collectionName) {
// IndexType
final IndexType INDEX_TYPE = IndexType.IVF_FLAT;
// ExtraParam 建议值为 4 × sqrt(n), 其中 n 指 segment 最多包含的 entity 条数。
final String INDEX_PARAM = "{\"nlist\":16384}";
long startIndexTime = System.currentTimeMillis();
R<RpcStatus> response = milvusClient.createIndex(CreateIndexParam.newBuilder()
.withCollectionName(collectionName)
.withIndexName(Content.CONTENT_INDEX)
.withFieldName(Content.Field.CONTENT_VECTOR)
.withMetricType(MetricType.L2)
.withIndexType(INDEX_TYPE)
.withExtraParam(INDEX_PARAM)
.withSyncMode(Boolean.TRUE)
.withSyncWaitingInterval(500L)
.withSyncWaitingTimeout(30L)
.build());
long endIndexTime = System.currentTimeMillis();
log.info("Succeed in " + (endIndexTime - startIndexTime) / 1000.00 + " seconds!");
log.info("createIndex --->>> {} ", response.toString());
GetIndexBuildProgressParam build = GetIndexBuildProgressParam.newBuilder()
.withCollectionName(collectionName)
.build();
R<GetIndexBuildProgressResponse> idnexResp = milvusClient.getIndexBuildProgress(build);
log.info("getIndexBuildProgress --->>> {}", idnexResp.getStatus());
return response.getData().getMsg().equals("Success");
}
public ReplyMsg insert(String collectionName, List<InsertParam.Field> fields) {
InsertParam insertParam = InsertParam.newBuilder()
.withCollectionName(collectionName)
.withFields(fields)
.build();
R<MutationResult> mutationResultR = milvusClient.insert(insertParam);
log.info("Flushing...");
long startFlushTime = System.currentTimeMillis();
milvusClient.flush(FlushParam.newBuilder()
.withCollectionNames(Collections.singletonList(collectionName))
.withSyncFlush(true)
.withSyncFlushWaitingInterval(50L)
.withSyncFlushWaitingTimeout(30L)
.build());
long endFlushTime = System.currentTimeMillis();
log.info("Succeed in " + (endFlushTime - startFlushTime) / 1000.00 + " seconds!");
if (mutationResultR.getStatus() == 0){
long insertCnt = mutationResultR.getData().getInsertCnt();
log.info("Successfully! Total number of entities inserted: {} ", insertCnt);
return ReplyMsg.ofSuccess("success", insertCnt);
}
log.error("InsertRequest failed!");
return ReplyMsg.ofErrorMsg("InsertRequest failed!");
}
public List<List<SearchResultVo>> searchTopKSimilarity(SearchParamVo searchParamVo) {
log.info("------search TopK Similarity------");
SearchParam searchParam = SearchParam.newBuilder()
.withCollectionName(searchParamVo.getCollectionName())
.withMetricType(MetricType.L2)
.withOutFields(searchParamVo.getOutputFields())
.withTopK(searchParamVo.getTopK())
.withVectors(searchParamVo.getQueryVectors())
.withVectorFieldName(Content.Field.CONTENT_VECTOR)
.withParams(searchParamVo.getParams())
.build();
R<SearchResults> respSearch = milvusClient.search(searchParam);
if (respSearch.getData() == null) {
return null;
}
log.info("------ process query results ------");
SearchResultsWrapper wrapper = new SearchResultsWrapper(respSearch.getData().getResults());
List<List<SearchResultVo>> result = new ArrayList<>();
for (int i = 0; i < searchParamVo.getQueryVectors().size(); ++i) {
List<SearchResultsWrapper.IDScore> scores = wrapper.getIDScore(i);
List<QueryResultsWrapper.RowRecord> rowRecords = wrapper.getRowRecords();
List<SearchResultVo> list = new ArrayList<>();
for (int j = 0; j < scores.size(); ++j) {
SearchResultsWrapper.IDScore score = scores.get(j);
QueryResultsWrapper.RowRecord rowRecord = rowRecords.get(j);
long longID = score.getLongID();
float distance = score.getScore();
String content = (String) rowRecord.get(searchParamVo.getOutputFields().get(0));
log.info("Top " + j + " ID:" + longID + " Distance:" + distance);
log.info("Content: " + content);
list.add(SearchResultVo.builder().id(longID).score(distance).conent(content).build());
}
result.add(list);
}
log.info("Successfully!");
return result;
}
public Boolean creatCollectionERP(String collectionName) {
// 主键字段
FieldType fieldType1 = FieldType.newBuilder()
.withName(Content.Field.ID)
.withDescription("primary key")
.withDataType(DataType.Int64)
.withPrimaryKey(true)
.withAutoID(true)
.build();
// 文本字段
FieldType fieldType2 = FieldType.newBuilder()
.withName(Content.Field.CONTENT)
.withDataType(DataType.VarChar)
.withMaxLength(Content.MAX_LENGTH)
.build();
// 向量字段
FieldType fieldType3 = FieldType.newBuilder()
.withName(Content.Field.CONTENT_VECTOR)
.withDataType(DataType.FloatVector)
.withDimension(Content.FEATURE_DIM)
.build();
FieldType fieldType4 = FieldType.newBuilder()
.withName(Content.Field.CONTENT_ANSWER)
.withDataType(DataType.VarChar)
.withMaxLength(Content.MAX_LENGTH)
.build();
FieldType fieldType5 = FieldType.newBuilder()
.withName(Content.Field.TITLE)
.withDataType(DataType.VarChar)
.withMaxLength(Content.MAX_LENGTH)
.build();
FieldType fieldType6 = FieldType.newBuilder()
.withName(Content.Field.PARAM)
.withDataType(DataType.VarChar)
.withMaxLength(Content.MAX_LENGTH)
.build();
FieldType fieldType7 = FieldType.newBuilder()
.withName(Content.Field.TYPE)
.withDataType(DataType.VarChar)
.withMaxLength(Content.MAX_LENGTH)
.build();
// 创建collection
CreateCollectionParam createCollectionReq = CreateCollectionParam.newBuilder()
.withCollectionName(collectionName)
.withDescription("Schema of Content ERP")
.withShardsNum(Content.SHARDS_NUM)
.addFieldType(fieldType1)
.addFieldType(fieldType2)
.addFieldType(fieldType3)
.addFieldType(fieldType4)
.addFieldType(fieldType5)
.addFieldType(fieldType6)
.addFieldType(fieldType7)
.build();
R<RpcStatus> response = milvusClient.createCollection(createCollectionReq);
log.info("collection: {} is created ? status: = {}", collectionName, response.getData().getMsg());
return response.getData().getMsg().equals("Success");
}
public Boolean creatCollectionERPCLIP(String collectionName) {
// 主键字段
FieldType fieldType1 = FieldType.newBuilder()
.withName(Content.Field.ID)
.withDescription("primary key")
.withDataType(DataType.Int64)
.withPrimaryKey(true)
.withAutoID(true)
.build();
// 文本字段
FieldType fieldType2 = FieldType.newBuilder()
.withName(Content.Field.CONTENT)
.withDataType(DataType.VarChar)
.withMaxLength(Content.MAX_LENGTH)
.build();
// 向量字段
FieldType fieldType3 = FieldType.newBuilder()
.withName(Content.Field.CONTENT_VECTOR)
.withDataType(DataType.FloatVector)
.withDimension(Content.FEATURE_DIM_CLIP)
.build();
FieldType fieldType4 = FieldType.newBuilder()
.withName(Content.Field.CONTENT_ANSWER)
.withDataType(DataType.VarChar)
.withMaxLength(Content.MAX_LENGTH)
.build();
FieldType fieldType5 = FieldType.newBuilder()
.withName(Content.Field.TITLE)
.withDataType(DataType.VarChar)
.withMaxLength(Content.MAX_LENGTH)
.build();
FieldType fieldType6 = FieldType.newBuilder()
.withName(Content.Field.PARAM)
.withDataType(DataType.VarChar)
.withMaxLength(Content.MAX_LENGTH)
.build();
FieldType fieldType7 = FieldType.newBuilder()
.withName(Content.Field.TYPE)
.withDataType(DataType.VarChar)
.withMaxLength(Content.MAX_LENGTH)
.build();
FieldType fieldType8 = FieldType.newBuilder()
.withName(Content.Field.LABEL)
.withDataType(DataType.VarChar)
.withMaxLength(Content.MAX_LENGTH)
.build();
// 创建collection
CreateCollectionParam createCollectionReq = CreateCollectionParam.newBuilder()
.withCollectionName(collectionName)
.withDescription("Schema of Content ERP")
.withShardsNum(Content.SHARDS_NUM)
.addFieldType(fieldType1)
.addFieldType(fieldType2)
.addFieldType(fieldType3)
.addFieldType(fieldType4)
.addFieldType(fieldType5)
.addFieldType(fieldType6)
.addFieldType(fieldType7)
.addFieldType(fieldType8)
.build();
R<RpcStatus> response = milvusClient.createCollection(createCollectionReq);
log.info("collection: {} is created ? status: = {}", collectionName, response.getData().getMsg());
return response.getData().getMsg().equals("Success");
}
public Boolean creatCollectionERPNLP(String collectionName) {
// 主键字段
FieldType fieldType1 = FieldType.newBuilder()
.withName(Content.Field.ID)
.withDescription("primary key")
.withDataType(DataType.Int64)
.withPrimaryKey(true)
.withAutoID(true)
.build();
// 文本字段
FieldType fieldType2 = FieldType.newBuilder()
.withName(Content.Field.CONTENT)
.withDataType(DataType.VarChar)
.withMaxLength(Content.MAX_LENGTH)
.build();
// 向量字段
FieldType fieldType3 = FieldType.newBuilder()
.withName(Content.Field.CONTENT_VECTOR)
.withDataType(DataType.FloatVector)
.withDimension(Content.FEATURE_DIM_CLIP)
.build();
FieldType fieldType4 = FieldType.newBuilder()
.withName(Content.Field.CONTENT_ANSWER)
.withDataType(DataType.VarChar)
.withMaxLength(Content.MAX_LENGTH)
.build();
FieldType fieldType5 = FieldType.newBuilder()
.withName(Content.Field.TITLE)
.withDataType(DataType.VarChar)
.withMaxLength(Content.MAX_LENGTH)
.build();
FieldType fieldType6 = FieldType.newBuilder()
.withName(Content.Field.PARAM)
.withDataType(DataType.VarChar)
.withMaxLength(Content.MAX_LENGTH)
.build();
FieldType fieldType7 = FieldType.newBuilder()
.withName(Content.Field.TYPE)
.withDataType(DataType.VarChar)
.withMaxLength(Content.MAX_LENGTH)
.build();
FieldType fieldType8 = FieldType.newBuilder()
.withName(Content.Field.LABEL)
.withDataType(DataType.VarChar)
.withMaxLength(Content.MAX_LENGTH)
.build();
// 创建collection
CreateCollectionParam createCollectionReq = CreateCollectionParam.newBuilder()
.withCollectionName(collectionName)
.withDescription("Schema of Content ERP")
.withShardsNum(Content.SHARDS_NUM)
.addFieldType(fieldType1)
.addFieldType(fieldType2)
.addFieldType(fieldType3)
.addFieldType(fieldType4)
.addFieldType(fieldType5)
.addFieldType(fieldType6)
.addFieldType(fieldType7)
.addFieldType(fieldType8)
.build();
R<RpcStatus> response = milvusClient.createCollection(createCollectionReq);
log.info("collection: {} is created ? status: = {}", collectionName, response.getData().getMsg());
return response.getData().getMsg().equals("Success");
}
public List<List<SearchERPResultVo>> searchERPTopKSimilarity(SearchERPParamVo searchParamVo) {
log.info("------search ERP TopK Similarity------");
SearchParam searchParam = SearchParam.newBuilder()
.withCollectionName(searchParamVo.getCollectionName())
.withMetricType(MetricType.L2)
.withOutFields(searchParamVo.getOutputFields())
.withTopK(searchParamVo.getTopK())
.withVectors(searchParamVo.getQueryVectors())
.withVectorFieldName(Content.Field.CONTENT_VECTOR)
.withParams(searchParamVo.getParams())
.build();
R<SearchResults> respSearch = milvusClient.search(searchParam);
if (respSearch.getData() == null) {
return null;
}
log.info("------ process query results ------");
SearchResultsWrapper wrapper = new SearchResultsWrapper(respSearch.getData().getResults());
List<List<SearchERPResultVo>> result = new ArrayList<>();
for (int i = 0; i < searchParamVo.getQueryVectors().size(); ++i) {
List<SearchResultsWrapper.IDScore> scores = wrapper.getIDScore(i);
List<QueryResultsWrapper.RowRecord> rowRecords = wrapper.getRowRecords();
List<SearchERPResultVo> list = new ArrayList<>();
for (int j = 0; j < scores.size(); ++j) {
SearchResultsWrapper.IDScore score = scores.get(j);
QueryResultsWrapper.RowRecord rowRecord = rowRecords.get(j);
long longID = score.getLongID();
float distance = score.getScore();
String content = (String) rowRecord.get(searchParamVo.getOutputFields().get(0));
String contentAnswer = (String) rowRecord.get(searchParamVo.getOutputFields().get(1));
String title = (String) rowRecord.get(searchParamVo.getOutputFields().get(2));
log.info("Top " + j + " ID:" + longID + " Distance:" + distance);
log.info("Content: " + content);
list.add(SearchERPResultVo.builder().id(longID).score(distance).content(content).contentAnswer(contentAnswer).title(title).build());
}
result.add(list);
}
log.info("Successfully!");
return result;
}
public List<List<SearchNLPResultVo>> searchNLPTopKSimilarity(SearchNLPParamVo searchParamVo) {
log.info("------search ERP TopK Similarity------");
SearchParam searchParam = SearchParam.newBuilder()
.withCollectionName(searchParamVo.getCollectionName())
.withMetricType(MetricType.L2)
.withOutFields(searchParamVo.getOutputFields())
.withTopK(searchParamVo.getTopK())
.withVectors(searchParamVo.getQueryVectors())
.withVectorFieldName(Content.Field.CONTENT_VECTOR)
.withParams(searchParamVo.getParams())
.withExpr(searchParamVo.getExpr())
.build();
R<SearchResults> respSearch = milvusClient.search(searchParam);
if (respSearch.getData() == null) {
return null;
}
log.info("------ process query results ------");
SearchResultsWrapper wrapper = new SearchResultsWrapper(respSearch.getData().getResults());
List<List<SearchNLPResultVo>> result = new ArrayList<>();
for (int i = 0; i < searchParamVo.getQueryVectors().size(); ++i) {
List<SearchResultsWrapper.IDScore> scores = wrapper.getIDScore(i);
List<QueryResultsWrapper.RowRecord> rowRecords = wrapper.getRowRecords();
List<SearchNLPResultVo> list = new ArrayList<>();
for (int j = 0; j < scores.size(); ++j) {
SearchResultsWrapper.IDScore score = scores.get(j);
QueryResultsWrapper.RowRecord rowRecord = rowRecords.get(j);
long longID = score.getLongID();
float distance = score.getScore();
String content = (String) rowRecord.get(searchParamVo.getOutputFields().get(0));
String contentAnswer = (String) rowRecord.get(searchParamVo.getOutputFields().get(1));
String title = (String) rowRecord.get(searchParamVo.getOutputFields().get(2));
log.info("Top " + j + " ID:" + longID + " Distance:" + distance);
log.info("Content: " + content);
list.add(SearchNLPResultVo.builder().id(longID).score(distance).content(content).contentAnswer(contentAnswer).title(title).build());
}
result.add(list);
}
log.info("Successfully!");
return result;
}
}
3、测试用例
MilvusServiceERPNLPTest
@SpringBootTest(classes = {DataChatgptApplication.class}, webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
public class MilvusServiceERPNLPTest {
@Autowired
MivusService milvusService;
@Autowired
MilvusClient milvusClient;
@Test
void isExitCollection() {
    // The NLP collection is expected to already exist in Milvus.
    Assertions.assertTrue(milvusService.isExitCollection(Content.COLLECTION_NAME_NLP));
}
@Test
void creatCollection() {
    // Creates the NLP collection using the labelled, CLIP-dimension schema.
    Assertions.assertTrue(milvusService.creatCollectionERPNLP(Content.COLLECTION_NAME_NLP));
}
@Test
void createIndex(){
    // Builds the vector index on the NLP collection and expects Milvus to report success.
    Assertions.assertTrue(milvusService.createIndex(Content.COLLECTION_NAME_NLP));
}
/**
 * Embeds one sample row with the Paddle service and inserts it into the NLP collection.
 * Failures now fail the test instead of being printed and ignored.
 */
@Test
public void insertVector() throws IOException {
    List<String> sentenceList = new ArrayList<>();
    sentenceList.add("网址是多少");
    List<String> contentAnswerList = new ArrayList<>();
    contentAnswerList.add("/home.ashx");
    List<String> titleList = new ArrayList<>();
    titleList.add("网址");
    List<String> paramList = new ArrayList<>();
    paramList.add("");
    List<String> typeList = new ArrayList<>();
    typeList.add("0");
    List<String> labelList = new ArrayList<>();
    labelList.add("操作直达");
    // 1. Embed the sentences; retry once, because getVectorsLists returns null on a failed call.
    PaddleNewTextVo paddleNewTextVo = getVectorsLists(sentenceList);
    if (paddleNewTextVo == null) {
        paddleNewTextVo = getVectorsLists(sentenceList);
    }
    Assertions.assertNotNull(paddleNewTextVo, "embedding service returned no vectors");
    // Milvus float vectors need Float, the embedding service returns Double.
    List<List<Float>> floatVectors = paddleNewTextVo.getVector().stream()
            .map(inner -> inner.stream().map(Double::floatValue).collect(Collectors.toList()))
            .collect(Collectors.toList());
    // 2. Prepare the column-wise insert payload.
    List<InsertParam.Field> fields = new ArrayList<>();
    fields.add(new InsertParam.Field(Content.Field.CONTENT, sentenceList));
    fields.add(new InsertParam.Field(Content.Field.CONTENT_VECTOR, floatVectors));
    fields.add(new InsertParam.Field(Content.Field.CONTENT_ANSWER, contentAnswerList));
    fields.add(new InsertParam.Field(Content.Field.TITLE, titleList));
    fields.add(new InsertParam.Field(Content.Field.PARAM, paramList));
    fields.add(new InsertParam.Field(Content.Field.TYPE, typeList));
    fields.add(new InsertParam.Field(Content.Field.LABEL, labelList));
    // 3. Insert.
    milvusService.insert(Content.COLLECTION_NAME_NLP, fields);
}
/**
 * Calls the Paddle embedding service for the given sentences.
 *
 * <p>The endpoint defaults to the previous hard-coded address and can be overridden with the
 * {@code paddle.url} system property. The request body is {@code {"data":[{"text":...},...]}}.
 *
 * @param sentenceList sentences to embed
 * @return the parsed response, or {@code null} on a non-200 status or an unparsable body
 * @throws IOException on connection/IO errors or if the request cannot be serialized
 */
private static PaddleNewTextVo getVectorsLists(List<String> sentenceList) throws IOException {
    String url = System.getProperty("paddle.url", "http://192.168.1.243:6001/"); // paddle
    HttpURLConnection con = (HttpURLConnection) new URL(url).openConnection();
    try {
        con.setConnectTimeout(50000);
        con.setReadTimeout(200000);
        con.setRequestMethod("POST");
        con.setRequestProperty("Content-Type", "application/json");
        con.setDoOutput(true);
        Map<String, List<Map<String, String>>> dataMap = new HashMap<>();
        dataMap.put("data", sentenceList.stream()
                .map(sentence -> Collections.singletonMap("text", sentence))
                .collect(Collectors.toList()));
        // JsonProcessingException is an IOException: report it instead of sending a null body.
        String jsonData = new ObjectMapper().writeValueAsString(dataMap);
        try (OutputStream os = con.getOutputStream()) {
            os.write(jsonData.getBytes(StandardCharsets.UTF_8));
        }
        int responseCode = con.getResponseCode();
        System.out.println("Response Code: " + responseCode);
        if (responseCode != HttpURLConnection.HTTP_OK) {
            System.out.println("Error Response Code: " + responseCode);
            // getErrorStream() may be null when the server sent no error body.
            java.io.InputStream errorStream = con.getErrorStream();
            if (errorStream != null) {
                try (BufferedReader errorReader = new BufferedReader(
                        new InputStreamReader(errorStream, StandardCharsets.UTF_8))) {
                    String errorMessage;
                    while ((errorMessage = errorReader.readLine()) != null) {
                        System.out.println("Error Message: " + errorMessage);
                    }
                }
            }
            return null;
        }
        StringBuilder content = new StringBuilder();
        try (BufferedReader in = new BufferedReader(
                new InputStreamReader(con.getInputStream(), StandardCharsets.UTF_8))) {
            String inputLine;
            while ((inputLine = in.readLine()) != null) {
                content.append(inputLine);
            }
        }
        try {
            return JSON.parseObject(content.toString(), PaddleNewTextVo.class);
        } catch (Exception e) {
            System.err.println("Error parsing JSON: " + e.getMessage());
            return null;
        }
    } finally {
        con.disconnect();
    }
}
/**
 * Embeds a query with the Paddle service and searches the NLP collection, filtered by label.
 * The previous version searched with an empty vector list, so there was nothing to search with.
 */
@Test
void searchTest() throws IOException {
    // 0. Load the collection so it can be searched.
    milvusService.loadCollection(Content.COLLECTION_NAME_NLP);
    List<String> sentenceList = new ArrayList<>();
    sentenceList.add("XX列表");
    String label = "操作直达";
    // 1. Embed the query sentence (same service insertVector uses).
    PaddleNewTextVo paddleNewTextVo = getVectorsLists(sentenceList);
    Assertions.assertNotNull(paddleNewTextVo, "embedding service returned no vectors");
    List<List<Float>> vectors = paddleNewTextVo.getVector().stream()
            .map(inner -> inner.stream().map(Double::floatValue).collect(Collectors.toList()))
            .collect(Collectors.toList());
    SearchNLPParamVo searchParamVo = SearchNLPParamVo.builder()
            .collectionName(Content.COLLECTION_NAME_NLP)
            .queryVectors(vectors)
            .expr("label == '" + label + "'")
            .topK(3)
            .build();
    // 2. Search the knowledge entries with the given label.
    List<List<SearchNLPResultVo>> lists = milvusService.searchNLPTopKSimilarity(searchParamVo);
    Assertions.assertNotNull(lists, "Milvus search returned no result");
    lists.forEach(searchResultVos -> {
        searchResultVos.forEach(searchResultVo -> {
            System.out.println(searchResultVo.getContent());
            System.out.println(searchResultVo.getContentAnswer());
            System.out.println(searchResultVo.getTitle());
            // NOTE(review): may print null if the search does not map the label - confirm.
            System.out.println(searchResultVo.getLabel());
        });
    });
}
/**
 * Loads rows from {@code text.txt}, embeds the sentences with DashScope and inserts them into
 * the NLP collection.
 *
 * <p>Line format: {@code title||||sentence||||answer[||||label]}. The NLP collection schema
 * includes a label field, so it must be supplied; it is taken from the optional 4th column,
 * defaulting to {@code ""} (TODO confirm the intended default label).
 */
@Test
public void insertTextVector() throws IOException, NoApiKeyException {
    List<String> titleList = new ArrayList<>();
    List<String> sentenceList = new ArrayList<>();
    List<String> contentAnswerList = new ArrayList<>();
    List<String> paramList = new ArrayList<>();
    List<String> typeList = new ArrayList<>();
    List<String> labelList = new ArrayList<>();
    String filePath = "src/main/resources/data/text.txt";
    // A read failure now fails the test instead of continuing with empty lists.
    try (BufferedReader reader = new BufferedReader(
            new InputStreamReader(new FileInputStream(filePath), StandardCharsets.UTF_8))) {
        String line;
        while ((line = reader.readLine()) != null) {
            // Fields are separated by four pipes (||||).
            String[] parts = line.split("\\|\\|\\|\\|");
            if (parts.length >= 3) {
                titleList.add(parts[0].trim());
                sentenceList.add(parts[1].trim());
                contentAnswerList.add(parts[2].trim());
                paramList.add("");
                typeList.add("2");
                labelList.add(parts.length >= 4 ? parts[3].trim() : "");
            } else {
                System.out.println("Warning: Invalid format on line: " + line);
            }
        }
    }
    System.out.println("Title List: " + titleList);
    System.out.println("Sentence List: " + sentenceList);
    System.out.println("Content Answer List: " + contentAnswerList);
    Assertions.assertFalse(sentenceList.isEmpty(), "no valid rows found in " + filePath);
    // 1. Embed the sentences.
    TextEmbeddingParam param = TextEmbeddingParam
            .builder()
            .model(TextEmbedding.Models.TEXT_EMBEDDING_V1)
            .texts(sentenceList).build();
    TextEmbeddingResult result = new TextEmbedding().call(param);
    List<List<Float>> vectors = result.getOutput().getEmbeddings().stream()
            .map(embedding -> embedding.getEmbedding().stream()
                    .map(Double::floatValue)
                    .collect(Collectors.toList()))
            .collect(Collectors.toList());
    // 2. Prepare the column-wise insert payload.
    List<InsertParam.Field> fields = new ArrayList<>();
    fields.add(new InsertParam.Field(Content.Field.CONTENT, sentenceList));
    fields.add(new InsertParam.Field(Content.Field.CONTENT_VECTOR, vectors));
    fields.add(new InsertParam.Field(Content.Field.CONTENT_ANSWER, contentAnswerList));
    fields.add(new InsertParam.Field(Content.Field.TITLE, titleList));
    fields.add(new InsertParam.Field(Content.Field.PARAM, paramList));
    fields.add(new InsertParam.Field(Content.Field.TYPE, typeList));
    fields.add(new InsertParam.Field(Content.Field.LABEL, labelList));
    // 3. Insert.
    milvusService.insert(Content.COLLECTION_NAME_NLP, fields);
}
/**
 * Retrieval-augmented chat: embeds a question, fetches the 3 closest knowledge entries from the
 * NLP collection, and asks Qwen to answer using them. Errors now fail the test instead of being
 * printed and swallowed, and an empty answer fails it too.
 */
@Test
void ChatBasedContentTest() throws NoApiKeyException, InputRequiredException, InterruptedException {
    // 0. Load the collection so it can be searched.
    milvusService.loadCollection(Content.COLLECTION_NAME_NLP);
    String question = "查询订单";
    List<String> sentenceList = new ArrayList<>();
    sentenceList.add(question);
    // 1. Embed the question.
    TextEmbeddingParam param = TextEmbeddingParam
            .builder()
            .model(TextEmbedding.Models.TEXT_EMBEDDING_V1)
            .texts(sentenceList).build();
    TextEmbeddingResult result = new TextEmbedding().call(param);
    List<Float> floatVector = result.getOutput().getEmbeddings().get(0).getEmbedding().stream()
            .map(Double::floatValue)
            .collect(Collectors.toList());
    SearchERPParamVo searchParamVo = SearchERPParamVo.builder()
            .collectionName(Content.COLLECTION_NAME_NLP)
            .queryVectors(Collections.singletonList(floatVector))
            .topK(3)
            .build();
    // 2. Retrieve the most similar knowledge entries.
    List<List<SearchERPResultVo>> lists = milvusService.searchERPTopKSimilarity(searchParamVo);
    Assertions.assertNotNull(lists, "Milvus search returned no result");
    StringBuilder buffer = new StringBuilder();
    for (List<SearchERPResultVo> searchResultVos : lists) {
        for (SearchERPResultVo searchResultVo : searchResultVos) {
            buffer.append("问题: ").append(searchResultVo.getContent());
            buffer.append("答案: ").append(searchResultVo.getContentAnswer());
        }
    }
    // 3. Ask the model, grounded in the retrieved entries.
    String prompt = "请你充分理解下面的内容,然后回答问题, 要求仅返回答案[]中内容:";
    String resultQwen = streamCallWithCallback(prompt + buffer + question);
    Assertions.assertFalse(resultQwen.isEmpty(), "model returned an empty answer");
}
/**
 * Streams a single-turn Qwen-Plus completion and returns the concatenated incremental output.
 *
 * <p>Blocks until the stream completes or errors. On error the message is printed and whatever was
 * received so far is returned.
 * NOTE(review): if neither callback fires, {@code semaphore.acquire()} waits forever.
 *
 * @param content user prompt
 * @return full streamed answer text
 */
public static String streamCallWithCallback(String content)
        throws NoApiKeyException, ApiException, InputRequiredException,InterruptedException {
    // SECURITY: an API key was previously hard-coded here and committed to source; it should be
    // revoked. The key now comes from the DASHSCOPE_API_KEY environment variable. When that is
    // unset, Constants.apiKey is left untouched (presumably the SDK then resolves the key itself
    // or throws NoApiKeyException - confirm).
    String apiKey = System.getenv("DASHSCOPE_API_KEY");
    if (apiKey != null && !apiKey.trim().isEmpty()) {
        Constants.apiKey = apiKey;
    }
    Generation gen = new Generation();
    Message userMsg = Message
            .builder()
            .role(Role.USER.getValue())
            .content(content)
            .build();
    QwenParam param = QwenParam
            .builder()
            .model(Generation.Models.QWEN_PLUS)
            .resultFormat(QwenParam.ResultFormat.MESSAGE)
            .messages(Arrays.asList(userMsg))
            .topP(0.8)
            .incrementalOutput(true) // get streaming output incrementally
            .build();
    // Released by onError/onComplete; acquire() then also makes fullContent's writes visible here.
    Semaphore semaphore = new Semaphore(0);
    StringBuilder fullContent = new StringBuilder();
    gen.streamCall(param, new ResultCallback<GenerationResult>() {
        @Override
        public void onEvent(GenerationResult message) {
            fullContent.append(message.getOutput().getChoices().get(0).getMessage().getContent());
            System.out.println(message);
        }
        @Override
        public void onError(Exception err){
            System.out.println(String.format("Exception: %s", err.getMessage()));
            semaphore.release();
        }
        @Override
        public void onComplete(){
            System.out.println("Completed");
            semaphore.release();
        }
    });
    semaphore.acquire();
    System.out.println("Full content: \n" + fullContent.toString());
    return fullContent.toString();
}
/**
 * Smoke test: parses the bundled dataset file and prints the first two normalised rows
 * produced by {@code getRows}.
 *
 * @throws IOException if the dataset file cannot be read
 */
@Test
void loadData() throws IOException {
    String json = readFileToString("src/main/resources/data/medium_articles_2020_dpr.json");
    System.out.println(getRows(JSON.parseObject(json).getJSONArray("rows"), 2));
}
/**
 * Reads the entire file at {@code filePath} and decodes its bytes as UTF-8.
 *
 * @param filePath path of the file to read
 * @return the file contents as a string
 * @throws IOException if the file cannot be read
 */
public String readFileToString(String filePath) throws IOException {
    byte[] raw = Files.readAllBytes(Paths.get(filePath));
    return new String(raw, StandardCharsets.UTF_8);
}
/**
 * Takes up to {@code counts} rows from the front of {@code dataset} and normalises each one.
 * {@code title_vector} is converted to a {@code List<Float>}; {@code reading_time},
 * {@code claps} and {@code responses} are re-stored as {@code Long}; {@code id} is removed.
 *
 * <p>The row objects obtained from {@code dataset} are modified in place and are the same objects
 * returned (presumably the array's own elements, so the caller's dataset sees the changes).
 *
 * @param dataset the {@code rows} array of the dataset; {@code null} yields an empty list
 * @param counts maximum number of rows to take. It is clamped to {@code dataset.size()} and
 *     values {@code <= 0} yield an empty list.
 * @return the normalised rows, in dataset order
 */
public static List<JSONObject> getRows(JSONArray dataset, int counts) {
    // Clamp so that requesting more rows than the file contains does not throw
    // IndexOutOfBoundsException.
    int limit = dataset == null ? 0 : Math.min(Math.max(counts, 0), dataset.size());
    List<JSONObject> rows = new ArrayList<>(limit);
    for (int i = 0; i < limit; i++) {
        JSONObject row = dataset.getJSONObject(i);
        List<Float> vectors = row.getJSONArray("title_vector").toJavaList(Float.class);
        Long readingTime = row.getLong("reading_time");
        Long claps = row.getLong("claps");
        Long responses = row.getLong("responses");
        row.put("title_vector", vectors);
        row.put("reading_time", readingTime);
        row.put("claps", claps);
        row.put("responses", responses);
        // The primary key is not inserted from the file.
        row.remove("id");
        rows.add(row);
    }
    return rows;
}
/**
 * Smoke test: builds the column-wise insert fields for the first dataset row via
 * {@code getFields} and prints them.
 *
 * @throws IOException if the dataset file cannot be read
 */
@Test
void getFileds() throws IOException {
    String json = readFileToString("src/main/resources/data/medium_articles_2020_dpr.json");
    List<InsertParam.Field> insertFields = getFields(JSON.parseObject(json).getJSONArray("rows"), 1);
    System.out.println(insertFields);
}
/**
 * Converts up to {@code counts} rows from the front of {@code dataset} into column-oriented
 * Milvus insert fields.
 *
 * <p>The fields are returned in this order: {@code title}, {@code title_vector}, {@code link},
 * {@code reading_time}, {@code publication}, {@code claps}, {@code responses}. Each field holds
 * one value per row.
 *
 * @param dataset the {@code rows} array of the dataset; {@code null} yields empty columns
 * @param counts maximum number of rows to convert. It is clamped to {@code dataset.size()} and
 *     values {@code <= 0} yield empty columns.
 * @return the seven insert fields, one list entry per field
 */
public static List<InsertParam.Field> getFields(JSONArray dataset, int counts) {
    // Clamp so that requesting more rows than the file contains does not throw
    // IndexOutOfBoundsException.
    int limit = dataset == null ? 0 : Math.min(Math.max(counts, 0), dataset.size());
    List<String> titles = new ArrayList<>(limit);
    List<List<Float>> titleVectors = new ArrayList<>(limit);
    List<String> links = new ArrayList<>(limit);
    List<Long> readingTimes = new ArrayList<>(limit);
    List<String> publications = new ArrayList<>(limit);
    List<Long> clapsList = new ArrayList<>(limit);
    List<Long> responsesList = new ArrayList<>(limit);
    for (int i = 0; i < limit; i++) {
        JSONObject row = dataset.getJSONObject(i);
        titles.add(row.getString("title"));
        titleVectors.add(row.getJSONArray("title_vector").toJavaList(Float.class));
        links.add(row.getString("link"));
        readingTimes.add(row.getLong("reading_time"));
        publications.add(row.getString("publication"));
        clapsList.add(row.getLong("claps"));
        responsesList.add(row.getLong("responses"));
    }
    List<InsertParam.Field> fields = new ArrayList<>(7);
    fields.add(new InsertParam.Field("title", titles));
    fields.add(new InsertParam.Field("title_vector", titleVectors));
    fields.add(new InsertParam.Field("link", links));
    fields.add(new InsertParam.Field("reading_time", readingTimes));
    fields.add(new InsertParam.Field("publication", publications));
    fields.add(new InsertParam.Field("claps", clapsList));
    fields.add(new InsertParam.Field("responses", responsesList));
    return fields;
}
/**
 * Integration test against the {@code medium_articles} collection.
 *
 * <p>The title vector of the first dataset row is used as the single query. The search uses
 * topK=3, L2 distance and the filter {@code claps > 30 and reading_time < 10}; the params string
 * also carries nprobe/offset/limit. For each hit the test prints ID, distance, title and link.
 * Requires a reachable Milvus instance behind {@code milvusClient}.
 *
 * @throws IOException if the dataset file cannot be read
 */
@Test
void searchTopKSimilarity() throws IOException {
    JSONObject dataset =
            JSON.parseObject(readFileToString("src/main/resources/data/medium_articles_2020_dpr.json"));
    List<JSONObject> sampleRows = getRows(dataset.getJSONArray("rows"), 10);

    // Exactly one query: the first sample row's own title embedding.
    List<Float> probe = sampleRows.get(0).getJSONArray("title_vector").toJavaList(Float.class);
    List<List<Float>> queryVectors = new ArrayList<>(Collections.singletonList(probe));

    List<String> returnedFields = new ArrayList<>(Arrays.asList("title", "link"));

    SearchParam request = SearchParam.newBuilder()
            .withCollectionName("medium_articles")
            .withVectorFieldName("title_vector")
            .withVectors(queryVectors)
            .withExpr("claps > 30 and reading_time < 10")
            .withTopK(3)
            .withMetricType(MetricType.L2)
            .withParams("{\"nprobe\":10,\"offset\":2, \"limit\":3}")
            .withConsistencyLevel(ConsistencyLevelEnum.BOUNDED)
            .withOutFields(returnedFields)
            .build();
    R<SearchResults> reply = milvusClient.search(request);
    SearchResultsWrapper results = new SearchResultsWrapper(reply.getData().getResults());

    System.out.println("Search results");
    for (int target = 0; target < queryVectors.size(); target++) {
        List<SearchResultsWrapper.IDScore> hits = results.getIDScore(target);
        // NOTE(review): getRowRecords() takes no target index, yet is indexed by per-target rank
        // below. With several query vectors this presumably pairs hits with the wrong records;
        // confirm against the SDK version in use.
        List<QueryResultsWrapper.RowRecord> records = results.getRowRecords();
        for (int rank = 0; rank < hits.size(); rank++) {
            SearchResultsWrapper.IDScore hit = hits.get(rank);
            QueryResultsWrapper.RowRecord record = records.get(rank);
            System.out.println("Top " + rank + " ID:" + hit.getLongID() + " Distance:" + hit.getScore());
            System.out.println("Title: " + record.get("title"));
            System.out.println("Link: " + record.get("link"));
        }
    }
}
}
4、查询
// 先根据向量查询语义相近的语料
List<Question> questionList = mivusService.searchNewPaddleQuestion(req.getMessage(), "1", appType);
**
* 根据问题进行向量查询,采用Paddle服务 采用新的文本分类方法
* @param question 用户的问题文本
* @return 相关的初始问题知识列表
*/
public List<Question> searchNewPaddleQuestion(String question, String type, String appType) {
// 0.加载向量集合
String collection = Content.COLLECTION_NAME_NLP;
if (appType.equals("1")) {
collection = Content.COLLECTION_NAME_NLP_APP;
}
loadCollection(collection);
List<Question> resultList = new LinkedList<>();
PaddleNewTextVo paddleNewTextVo = null;
try {
List<String> sentenceList = new ArrayList<>();
sentenceList.add(question);
// 1.获得向量
paddleNewTextVo = getNewNLPVectorsLists(sentenceList);
log.info("实时向量值 : {}", paddleNewTextVo.getPredictedList());
List<List<Double>> vectors = paddleNewTextVo.getVector();
List<List<Float>> floatVectors = new ArrayList<>();
for (List<Double> innerList : vectors) {
List<Float> floatInnerList = new ArrayList<>();
for (Double value : innerList) {
floatInnerList.add(value.floatValue());
}
floatVectors.add(floatInnerList);
}
List<Integer> predictedList = paddleNewTextVo.getPredictedList();
List<String> labelStrings = new ArrayList<>();
HashSet<Integer> setType = new HashSet();
int topK = 3;
if(!predictedList.isEmpty()) {
// 去重
for (Integer number : predictedList) {
setType.add(number);
if (number == 2) {
// 如何是 2
topK = 1;
}
}
for (Integer label : setType) {
labelStrings.add("'" + label + "'");
}
}
String typeResult = "[" + String.join(", ", labelStrings) + "]";
SearchNLPParamVo searchParamVo = SearchNLPParamVo.builder()
.collectionName(collection)
//.expr("type == '" + type + "'")
.expr("type in ['0','1','2']")
//.expr("type in " + typeResult + " ")
.queryVectors(floatVectors)
.topK(topK)
.build();
// 2.在向量数据库中进行搜索内容知识
List<List<SearchNLPResultVo>> lists = searchNLPERPTopKSimilarity(searchParamVo);
lists.forEach(searchResultVos -> {
searchResultVos.forEach(searchResultVo -> {
log.info(searchResultVo.getContent());
log.info(searchResultVo.getContentAnswer());
Question question1 = new Question();
question1.setQuestionId(Long.valueOf(searchResultVo.getId()));
question1.setQuestion(searchResultVo.getContent());
question1.setAnswer(searchResultVo.getContentAnswer());
question1.setTitle(searchResultVo.getTitle());
question1.setParam(searchResultVo.getParam());
question1.setType(searchResultVo.getType());
question1.setLabel(searchResultVo.getLabel());
resultList.add(question1);
});
});
} catch (ApiException | IOException e) {
log.error(e.getMessage());
}
// 将查询到的结果转换为之前构造的 Question 的格式返回给前端
return resultList;
}