Milvus 基本操作

1、Maven 依赖

<!-- Milvus Java SDK 2.3.3. slf4j-api and log4j-slf4j-impl are excluded, presumably so the
     application's own logging setup is used instead of the SDK's transitive one (confirm). -->
<dependency>
    <groupId>io.milvus</groupId>
    <artifactId>milvus-sdk-java</artifactId>
    <version>2.3.3</version>
    <exclusions>
        <exclusion>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-api</artifactId>
        </exclusion>
        <exclusion>
            <groupId>org.apache.logging.log4j</groupId>
            <artifactId>log4j-slf4j-impl</artifactId>
        </exclusion>
    </exclusions>
</dependency>

2、MivusService 封装了 Milvus 的基本操作

/**
 * Spring service wrapping {@link MilvusServiceClient} with the Milvus operations this project uses:
 * collection/partition lifecycle, index creation, insert + flush and vector similarity search.
 *
 * <p>Failure handling: methods returning {@code Boolean}/{@code ReplyMsg} report a failed RPC through
 * their return value, the {@code search*TopKSimilarity} methods return {@code null}, and the other
 * methods throw {@link IllegalStateException} carrying the raw {@code R} response instead of failing
 * with a {@code NullPointerException} on the response's missing data.
 */
@Service
@Slf4j
public class MivusService {

    @Autowired
    MilvusServiceClient milvusClient;

    /**
     * Synchronously searches {@code collectionName} using L2 distance on the {@code "embeddings"}
     * vector field.
     *
     * @param collectionName collection name
     * @param vectors        query vectors; must be non-empty
     * @param topK           number of most similar entities to return per query vector
     * @return int64 primary keys of the hits, concatenated across all query vectors
     * @throws IllegalArgumentException if an argument is null or empty
     * @throws IllegalStateException    if the search RPC fails
     */
    public List<Long> search(String collectionName, List<List<Float>> vectors, Integer topK) {
        checkSearchArgs(collectionName, vectors, topK);
        return searchIds(collectionName, vectors, topK, MetricType.L2, "embeddings", null);
    }

    /**
     * Synchronously searches {@code collectionName} using inner product (IP) on the
     * {@code "embedding"} vector field.
     *
     * @param collectionName collection name
     * @param vectors        query vectors; must be non-empty
     * @param topK           number of most similar entities to return per query vector
     * @return int64 primary keys of the hits, concatenated across all query vectors
     * @throws IllegalArgumentException if an argument is null or empty
     * @throws IllegalStateException    if the search RPC fails
     */
    public List<Long> search1(String collectionName, List<List<Float>> vectors, Integer topK) {
        checkSearchArgs(collectionName, vectors, topK);
        return searchIds(collectionName, vectors, topK, MetricType.IP, "embedding", null);
    }

    /**
     * Synchronously searches {@code collectionName} (IP metric, {@code "embedding"} field) restricted
     * by a boolean filter expression.
     *
     * @param collectionName collection name
     * @param vectors        query vectors; must be non-empty
     * @param topK           number of most similar entities to return per query vector
     * @param exp            filter expression, e.g. {@code status=1}
     * @return int64 primary keys of the hits, concatenated across all query vectors
     * @throws IllegalArgumentException if an argument is null or empty
     * @throws IllegalStateException    if the search RPC fails
     */
    public List<Long> search2(String collectionName, List<List<Float>> vectors, Integer topK, String exp) {
        checkSearchArgs(collectionName, vectors, topK);
        Assert.notNull(exp, "exp is null");
        return searchIds(collectionName, vectors, topK, MetricType.IP, "embedding", exp);
    }

    /**
     * Searches the given partitions of {@code collectionName} through the SDK's async API and blocks
     * until the result arrives.
     *
     * <p>NOTE(review): unlike the synchronous searches, no metric type or vector field name is set
     * here; confirm the SDK/server accepts that for the target collection.
     *
     * @param collectionName collection name
     * @param vectors        query vectors; must be non-empty
     * @param partitionList  partitions to search; must be non-empty
     * @param topK           number of most similar entities to return per query vector
     * @return int64 primary keys of the hits, concatenated across all query vectors
     * @throws ExecutionException    if the async call completed exceptionally
     * @throws InterruptedException  if interrupted while waiting for the result
     * @throws IllegalStateException if the search RPC reports a failure
     */
    public List<Long> searchAsync(String collectionName, List<List<Float>> vectors,
                                  List<String> partitionList, Integer topK) throws ExecutionException, InterruptedException {
        Assert.notNull(collectionName, "collectionName  is null");
        Assert.notNull(vectors, "vectors is null");
        Assert.notEmpty(vectors, "vectors is empty");
        Assert.notNull(partitionList, "partitionList is null");
        Assert.notEmpty(partitionList, "partitionList is empty");
        Assert.notNull(topK, "topK is null");
        SearchParam searchParam = SearchParam.newBuilder()
                .withCollectionName(collectionName)
                .withParams(nprobeParams(vectors))
                .withVectors(vectors)
                .withTopK(topK)
                .withPartitionNames(partitionList)
                .build();
        ListenableFuture<R<SearchResults>> future = milvusClient.searchAsync(searchParam);
        // getTopksList() holds the number of hits per query, not entity ids; read the id column instead.
        return requireData(future.get(), "searchAsync " + collectionName)
                .getResults().getIds().getIntId().getDataList();
    }

    /**
     * Lists the partition names of a collection.
     *
     * @param collectionName collection name
     * @return mutable list of partition names
     * @throws IllegalStateException if the showPartitions RPC fails
     */
    public List<String> getPartitionsList(String collectionName) {
        Assert.notNull(collectionName, "collectionName  is null");
        ShowPartitionsParam param = ShowPartitionsParam.newBuilder().withCollectionName(collectionName).build();
        return new ArrayList<>(requireData(milvusClient.showPartitions(param), "showPartitions " + collectionName)
                .getPartitionNamesList());
    }

    /** Loads a collection into memory; throws {@link IllegalStateException} if the RPC fails. */
    public void loadCollection(String collectionName) {
        LoadCollectionParam loadCollectionParam = LoadCollectionParam.newBuilder()
                .withCollectionName(collectionName)
                .build();
        logRpcStatus("loadCollection", collectionName, milvusClient.loadCollection(loadCollectionParam));
    }

    /** Releases a loaded collection; throws {@link IllegalStateException} if the RPC fails. */
    public void releaseCollection(String collectionName) {
        ReleaseCollectionParam param = ReleaseCollectionParam.newBuilder()
                .withCollectionName(collectionName)
                .build();
        logRpcStatus("releaseCollection", collectionName, milvusClient.releaseCollection(param));
    }

    /** Loads the given partitions of a collection; throws {@link IllegalStateException} on failure. */
    public void loadPartitions(String collectionName, List<String> partitionsName) {
        LoadPartitionsParam param = LoadPartitionsParam.newBuilder()
                .withCollectionName(collectionName)
                .withPartitionNames(partitionsName)
                .build();
        logRpcStatus("loadPartitions", collectionName + " " + partitionsName, milvusClient.loadPartitions(param));
    }

    /** Releases the given partitions of a collection; throws {@link IllegalStateException} on failure. */
    public void releasePartitions(String collectionName, List<String> partitionsName) {
        ReleasePartitionsParam param = ReleasePartitionsParam.newBuilder()
                .withCollectionName(collectionName)
                .withPartitionNames(partitionsName)
                .build();
        logRpcStatus("releasePartitions", collectionName + " " + partitionsName, milvusClient.releasePartitions(param));
    }

    /**
     * Checks whether a collection exists.
     *
     * @return {@code true} if Milvus reports the collection exists
     * @throws IllegalStateException if the hasCollection RPC fails (a failure is not "does not exist")
     */
    public boolean isExitCollection(String collectionName) {
        HasCollectionParam hasCollectionParam = HasCollectionParam.newBuilder()
                .withCollectionName(collectionName)
                .build();
        boolean exists = Boolean.TRUE.equals(
                requireData(milvusClient.hasCollection(hasCollectionParam), "hasCollection " + collectionName));
        log.info("collection {} is exists: {}", collectionName, exists);
        return exists;
    }

    /**
     * Creates the basic content collection: auto-id int64 primary key, VarChar content and a float
     * vector of {@code Content.FEATURE_DIM} dimensions.
     *
     * @return {@code true} if Milvus reports success
     */
    public Boolean creatCollection(String collectionName) {
        return createCollectionWithFields(collectionName, "Schema of Content",
                primaryKeyField(),
                varCharField(Content.Field.CONTENT),
                vectorField(Content.FEATURE_DIM));
    }

    /**
     * Drops a collection.
     *
     * @return {@code true} if Milvus reports success
     */
    public Boolean dropCollection(String collectionName) {
        DropCollectionParam param = DropCollectionParam.newBuilder()
                .withCollectionName(collectionName)
                .build();
        R<RpcStatus> response = milvusClient.dropCollection(param);
        boolean dropped = isSuccess(response);
        if (!dropped) {
            log.warn("dropCollection {} failed: {}", collectionName, response);
        }
        return dropped;
    }

    /** Creates a partition in a collection; throws {@link IllegalStateException} if the RPC fails. */
    public void createPartition(String collectionName, String partitionName) {
        CreatePartitionParam param = CreatePartitionParam.newBuilder()
                .withCollectionName(collectionName)
                .withPartitionName(partitionName)
                .build();
        RpcStatus status = requireData(milvusClient.createPartition(param),
                "createPartition " + partitionName + " in " + collectionName);
        log.info("create partition: {} in collection: {} is: {}", partitionName, collectionName, status.getMsg());
    }

    /**
     * Builds an IVF_FLAT / L2 index on the content-vector field and waits for it synchronously.
     *
     * @return {@code true} if Milvus reports success
     */
    public Boolean createIndex(String collectionName) {
        final IndexType indexType = IndexType.IVF_FLAT;
        // nlist: suggested value is about 4 * sqrt(n), n = max number of entities in a segment.
        final String indexParam = "{\"nlist\":16384}";
        long startIndexTime = System.currentTimeMillis();
        R<RpcStatus> response = milvusClient.createIndex(CreateIndexParam.newBuilder()
                .withCollectionName(collectionName)
                .withIndexName(Content.CONTENT_INDEX)
                .withFieldName(Content.Field.CONTENT_VECTOR)
                .withMetricType(MetricType.L2)
                .withIndexType(indexType)
                .withExtraParam(indexParam)
                .withSyncMode(Boolean.TRUE)
                .withSyncWaitingInterval(500L)
                .withSyncWaitingTimeout(30L)
                .build());
        double seconds = (System.currentTimeMillis() - startIndexTime) / 1000.00;
        log.info("createIndex --->>> {} ", response);
        if (!isSuccess(response)) {
            log.warn("createIndex on {} failed after {} seconds", collectionName, seconds);
            return false;
        }
        log.info("Succeed in {} seconds!", seconds);
        GetIndexBuildProgressParam progressParam = GetIndexBuildProgressParam.newBuilder()
                .withCollectionName(collectionName)
                .build();
        R<GetIndexBuildProgressResponse> progress = milvusClient.getIndexBuildProgress(progressParam);
        log.info("getIndexBuildProgress --->>> {}", progress.getStatus());
        return true;
    }

    /**
     * Inserts column data into a collection and, on success, flushes it synchronously.
     *
     * @return success reply carrying the insert count, or an error reply if the insert failed
     */
    public ReplyMsg insert(String collectionName, List<InsertParam.Field> fields) {
        InsertParam insertParam = InsertParam.newBuilder()
                .withCollectionName(collectionName)
                .withFields(fields)
                .build();
        R<MutationResult> mutationResultR = milvusClient.insert(insertParam);
        // Check the insert before flushing: flushing after a failed insert is wasted work.
        if (!isSuccess(mutationResultR) || mutationResultR.getData() == null) {
            log.error("InsertRequest failed! {}", mutationResultR);
            return ReplyMsg.ofErrorMsg("InsertRequest failed!");
        }

        log.info("Flushing...");
        long startFlushTime = System.currentTimeMillis();
        R<?> flushResponse = milvusClient.flush(FlushParam.newBuilder()
                .withCollectionNames(Collections.singletonList(collectionName))
                .withSyncFlush(true)
                .withSyncFlushWaitingInterval(50L)
                .withSyncFlushWaitingTimeout(30L)
                .build());
        double seconds = (System.currentTimeMillis() - startFlushTime) / 1000.00;
        if (isSuccess(flushResponse)) {
            log.info("Succeed in {} seconds!", seconds);
        } else {
            // The insert itself succeeded, so keep reporting success; only the flush is flagged.
            log.warn("flush of {} failed after {} seconds: {}", collectionName, seconds, flushResponse);
        }

        long insertCnt = mutationResultR.getData().getInsertCnt();
        log.info("Successfully! Total number of entities inserted: {} ", insertCnt);
        return ReplyMsg.ofSuccess("success", insertCnt);
    }

    /**
     * L2 search on the content-vector field returning {@code id}, distance and the first output field
     * as content, grouped per query vector.
     *
     * @return one list of hits per query vector, or {@code null} if the search returned no data
     */
    public List<List<SearchResultVo>> searchTopKSimilarity(SearchParamVo searchParamVo) {
        log.info("------search TopK Similarity------");
        List<String> outputFields = searchParamVo.getOutputFields();
        SearchResultsWrapper wrapper = searchContentVectors(searchParamVo.getCollectionName(), outputFields,
                searchParamVo.getTopK(), searchParamVo.getQueryVectors(), searchParamVo.getParams(), null);
        if (wrapper == null) {
            return null;
        }
        log.info("------ process query results ------");
        List<List<SearchResultVo>> result = new ArrayList<>();
        for (int i = 0; i < searchParamVo.getQueryVectors().size(); ++i) {
            List<SearchResultsWrapper.IDScore> scores = wrapper.getIDScore(i);
            // Field data must be read per query target so it lines up with that target's scores.
            List<?> contents = wrapper.getFieldData(outputFields.get(0), i);
            List<SearchResultVo> hits = new ArrayList<>(scores.size());
            for (int j = 0; j < scores.size(); ++j) {
                SearchResultsWrapper.IDScore score = scores.get(j);
                String content = (String) contents.get(j);
                log.info("Top {} ID:{} Distance:{}", j, score.getLongID(), score.getScore());
                log.info("Content: {}", content);
                hits.add(SearchResultVo.builder().id(score.getLongID()).score(score.getScore()).conent(content).build());
            }
            result.add(hits);
        }
        log.info("Successfully!");
        return result;
    }

    /**
     * Creates the ERP collection: the basic content schema plus VarChar answer, title, param and
     * type fields.
     *
     * @return {@code true} if Milvus reports success
     */
    public Boolean creatCollectionERP(String collectionName) {
        return createCollectionWithFields(collectionName, "Schema of Content ERP",
                primaryKeyField(),
                varCharField(Content.Field.CONTENT),
                vectorField(Content.FEATURE_DIM),
                varCharField(Content.Field.CONTENT_ANSWER),
                varCharField(Content.Field.TITLE),
                varCharField(Content.Field.PARAM),
                varCharField(Content.Field.TYPE));
    }

    /**
     * Creates the ERP collection variant with {@code Content.FEATURE_DIM_CLIP}-dimensional vectors
     * and an extra VarChar label field.
     *
     * @return {@code true} if Milvus reports success
     */
    public Boolean creatCollectionERPCLIP(String collectionName) {
        return createErpLabeledCollection(collectionName);
    }

    /**
     * Creates the NLP collection; its schema is identical to {@link #creatCollectionERPCLIP}.
     *
     * @return {@code true} if Milvus reports success
     */
    public Boolean creatCollectionERPNLP(String collectionName) {
        return createErpLabeledCollection(collectionName);
    }

    /**
     * L2 search returning content, answer and title (output fields 0..2) per hit, grouped per query.
     *
     * @return one list of hits per query vector, or {@code null} if the search returned no data
     */
    public List<List<SearchERPResultVo>> searchERPTopKSimilarity(SearchERPParamVo searchParamVo) {
        log.info("------search ERP TopK Similarity------");
        List<String> outputFields = searchParamVo.getOutputFields();
        SearchResultsWrapper wrapper = searchContentVectors(searchParamVo.getCollectionName(), outputFields,
                searchParamVo.getTopK(), searchParamVo.getQueryVectors(), searchParamVo.getParams(), null);
        if (wrapper == null) {
            return null;
        }
        log.info("------ process query results ------");
        List<List<SearchERPResultVo>> result = new ArrayList<>();
        for (int i = 0; i < searchParamVo.getQueryVectors().size(); ++i) {
            List<SearchResultsWrapper.IDScore> scores = wrapper.getIDScore(i);
            List<?> contents = wrapper.getFieldData(outputFields.get(0), i);
            List<?> answers = wrapper.getFieldData(outputFields.get(1), i);
            List<?> titles = wrapper.getFieldData(outputFields.get(2), i);
            List<SearchERPResultVo> hits = new ArrayList<>(scores.size());
            for (int j = 0; j < scores.size(); ++j) {
                SearchResultsWrapper.IDScore score = scores.get(j);
                String content = (String) contents.get(j);
                log.info("Top {} ID:{} Distance:{}", j, score.getLongID(), score.getScore());
                log.info("Content: {}", content);
                hits.add(SearchERPResultVo.builder().id(score.getLongID()).score(score.getScore())
                        .content(content).contentAnswer((String) answers.get(j)).title((String) titles.get(j)).build());
            }
            result.add(hits);
        }
        log.info("Successfully!");
        return result;
    }

    /**
     * Filtered L2 search (optional {@code expr}) returning content, answer and title per hit, grouped
     * per query.
     *
     * @return one list of hits per query vector, or {@code null} if the search returned no data
     */
    public List<List<SearchNLPResultVo>> searchNLPTopKSimilarity(SearchNLPParamVo searchParamVo) {
        log.info("------search NLP TopK Similarity------");
        List<String> outputFields = searchParamVo.getOutputFields();
        SearchResultsWrapper wrapper = searchContentVectors(searchParamVo.getCollectionName(), outputFields,
                searchParamVo.getTopK(), searchParamVo.getQueryVectors(), searchParamVo.getParams(),
                searchParamVo.getExpr());
        if (wrapper == null) {
            return null;
        }
        log.info("------ process query results ------");
        List<List<SearchNLPResultVo>> result = new ArrayList<>();
        for (int i = 0; i < searchParamVo.getQueryVectors().size(); ++i) {
            List<SearchResultsWrapper.IDScore> scores = wrapper.getIDScore(i);
            List<?> contents = wrapper.getFieldData(outputFields.get(0), i);
            List<?> answers = wrapper.getFieldData(outputFields.get(1), i);
            List<?> titles = wrapper.getFieldData(outputFields.get(2), i);
            List<SearchNLPResultVo> hits = new ArrayList<>(scores.size());
            for (int j = 0; j < scores.size(); ++j) {
                SearchResultsWrapper.IDScore score = scores.get(j);
                String content = (String) contents.get(j);
                log.info("Top {} ID:{} Distance:{}", j, score.getLongID(), score.getScore());
                log.info("Content: {}", content);
                hits.add(SearchNLPResultVo.builder().id(score.getLongID()).score(score.getScore())
                        .content(content).contentAnswer((String) answers.get(j)).title((String) titles.get(j)).build());
            }
            result.add(hits);
        }
        log.info("Successfully!");
        return result;
    }

    // ------------------------------------------------------------------ private helpers

    /** Validates the arguments shared by the synchronous id searches (Spring Assert -> IAE). */
    private static void checkSearchArgs(String collectionName, List<List<Float>> vectors, Integer topK) {
        Assert.notNull(collectionName, "collectionName  is null");
        Assert.notNull(vectors, "vectors is null");
        Assert.notEmpty(vectors, "vectors is empty");
        Assert.notNull(topK, "topK is null");
    }

    /**
     * Builds the legacy {@code nprobe} search parameter.
     *
     * <p>NOTE(review): nprobe is the number of IVF clusters to probe, yet it is derived from the
     * dimension of the first query vector. Kept for backward compatibility; consider making it an
     * explicit parameter (it must not exceed the index's nlist).
     */
    private static String nprobeParams(List<List<Float>> vectors) {
        return "{\"nprobe\": " + vectors.get(0).size() + "}";
    }

    /** Runs a synchronous search and returns the int64 ids of all hits. */
    private List<Long> searchIds(String collectionName, List<List<Float>> vectors, Integer topK,
                                 MetricType metricType, String vectorFieldName, String expr) {
        SearchParam.Builder builder = SearchParam.newBuilder()
                .withCollectionName(collectionName)
                .withParams(nprobeParams(vectors))
                .withMetricType(metricType)
                .withVectors(vectors)
                .withVectorFieldName(vectorFieldName)
                .withTopK(topK);
        if (expr != null) {
            builder.withExpr(expr);
        }
        R<SearchResults> response = milvusClient.search(builder.build());
        return requireData(response, "search " + collectionName).getResults().getIds().getIntId().getDataList();
    }

    /** Runs an L2 search on the content-vector field with output fields; null when no data returned. */
    private SearchResultsWrapper searchContentVectors(String collectionName, List<String> outputFields,
                                                      Integer topK, List<?> queryVectors,
                                                      String params, String expr) {
        SearchParam.Builder builder = SearchParam.newBuilder()
                .withCollectionName(collectionName)
                .withMetricType(MetricType.L2)
                .withOutFields(outputFields)
                .withTopK(topK)
                .withVectors(queryVectors)
                .withVectorFieldName(Content.Field.CONTENT_VECTOR)
                .withParams(params);
        if (expr != null) {
            builder.withExpr(expr);
        }
        R<SearchResults> response = milvusClient.search(builder.build());
        if (response.getData() == null) {
            log.warn("search on {} returned no data: {}", collectionName, response);
            return null;
        }
        return new SearchResultsWrapper(response.getData().getResults());
    }

    /** Creates the labeled ERP schema shared by the CLIP and NLP collections. */
    private Boolean createErpLabeledCollection(String collectionName) {
        return createCollectionWithFields(collectionName, "Schema of Content ERP",
                primaryKeyField(),
                varCharField(Content.Field.CONTENT),
                vectorField(Content.FEATURE_DIM_CLIP),
                varCharField(Content.Field.CONTENT_ANSWER),
                varCharField(Content.Field.TITLE),
                varCharField(Content.Field.PARAM),
                varCharField(Content.Field.TYPE),
                varCharField(Content.Field.LABEL));
    }

    /** Sends a createCollection RPC with the given fields; returns whether Milvus reported success. */
    private Boolean createCollectionWithFields(String collectionName, String description, FieldType... fields) {
        CreateCollectionParam.Builder builder = CreateCollectionParam.newBuilder()
                .withCollectionName(collectionName)
                .withDescription(description)
                .withShardsNum(Content.SHARDS_NUM);
        for (FieldType field : fields) {
            builder.addFieldType(field);
        }
        R<RpcStatus> response = milvusClient.createCollection(builder.build());
        boolean created = isSuccess(response) && response.getData() != null;
        log.info("collection: {} is created ? status: = {}", collectionName,
                created ? response.getData().getMsg() : response);
        return created;
    }

    /** Auto-generated int64 primary-key field. */
    private static FieldType primaryKeyField() {
        return FieldType.newBuilder()
                .withName(Content.Field.ID)
                .withDescription("primary key")
                .withDataType(DataType.Int64)
                .withPrimaryKey(true)
                .withAutoID(true)
                .build();
    }

    /** VarChar field limited to {@code Content.MAX_LENGTH}. */
    private static FieldType varCharField(String name) {
        return FieldType.newBuilder()
                .withName(name)
                .withDataType(DataType.VarChar)
                .withMaxLength(Content.MAX_LENGTH)
                .build();
    }

    /** Float-vector field named {@code Content.Field.CONTENT_VECTOR} with the given dimension. */
    private static FieldType vectorField(Integer dimension) {
        return FieldType.newBuilder()
                .withName(Content.Field.CONTENT_VECTOR)
                .withDataType(DataType.FloatVector)
                .withDimension(dimension)
                .build();
    }

    /** True when the response carries the SDK's success status code. */
    private static boolean isSuccess(R<?> response) {
        return response != null && response.getStatus() != null
                && response.getStatus() == R.Status.Success.getCode();
    }

    /** Returns the response data, or throws IllegalStateException describing the failed RPC. */
    private static <T> T requireData(R<T> response, String action) {
        if (!isSuccess(response) || response.getData() == null) {
            throw new IllegalStateException(action + " failed: " + response);
        }
        return response.getData();
    }

    /** Logs the status message of a successful lifecycle RPC; throws if the RPC failed. */
    private static void logRpcStatus(String action, String target, R<RpcStatus> response) {
        RpcStatus status = requireData(response, action + " " + target);
        log.info("{} {} is {}", action, target, status.getMsg());
    }
}

3、测试用例

MilvusServiceERPNLPTest

@SpringBootTest(classes = {DataChatgptApplication.class}, webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
public class MilvusServiceERPNLPTest {

    /** Paddle text-embedding endpoint called by {@link #getVectorsLists(List)}. */
    private static final String PADDLE_URL = "http://192.168.1.243:6001/";

    /** Environment variable that supplies the DashScope API key to {@link #streamCallWithCallback(String)}. */
    private static final String DASHSCOPE_API_KEY_ENV = "DASHSCOPE_API_KEY";

    /** Shared JSON serializer for Paddle request bodies. */
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    @Autowired
    MivusService milvusService;
    @Autowired
    MilvusClient milvusClient;

    /** Asserts that the NLP collection exists. */
    @Test
    void isExitCollection() {
        boolean mediumArticles = milvusService.isExitCollection(Content.COLLECTION_NAME_NLP);
        Assertions.assertTrue(mediumArticles);
    }

    /** Creates the NLP collection and asserts the service reports success. */
    @Test
    void creatCollection() {
        Boolean created = milvusService.creatCollectionERPNLP(Content.COLLECTION_NAME_NLP);
        Assertions.assertTrue(created);
    }

    /** Builds the index on the NLP collection and asserts the service reports success. */
    @Test
    void createIndex(){
        Boolean index = milvusService.createIndex(Content.COLLECTION_NAME_NLP);
        Assertions.assertTrue(index);
    }

    /**
     * Embeds one sample sentence through the Paddle service and inserts it, together with its
     * answer/title/param/type/label columns, into the NLP collection.
     */
    @Test
    public void insertVector(){
        List<String> sentenceList = new ArrayList<>();
        sentenceList.add("网址是多少");

        List<String> contentAnswerList = new ArrayList<>();
        contentAnswerList.add("/home.ashx");

        List<String> titleList = new ArrayList<>();
        titleList.add("网址");

        List<String> paramList = new ArrayList<>();
        paramList.add("");

        List<String> typeList = new ArrayList<>();
        typeList.add("0");

        List<String> labelList = new ArrayList<>();
        labelList.add("操作直达");

        try {
            // 1. Obtain the embedding; retry once because the service can return nothing.
            PaddleNewTextVo paddleNewTextVo = getVectorsLists(sentenceList);
            if (paddleNewTextVo == null) {
                paddleNewTextVo = getVectorsLists(sentenceList);
            }
            // Fail with a clear message instead of an NPE when the retry also yields nothing.
            Assertions.assertTrue(paddleNewTextVo != null && paddleNewTextVo.getVector() != null,
                    "Paddle embedding service returned no vectors");
            List<List<Float>> floatVectors = toFloatVectors(paddleNewTextVo.getVector());

            // 2. Build the column-oriented insert payload.
            List<InsertParam.Field> fields = new ArrayList<>();
            fields.add(new InsertParam.Field(Content.Field.CONTENT, sentenceList));
            fields.add(new InsertParam.Field(Content.Field.CONTENT_VECTOR, floatVectors));
            fields.add(new InsertParam.Field(Content.Field.CONTENT_ANSWER, contentAnswerList));
            fields.add(new InsertParam.Field(Content.Field.TITLE, titleList));
            fields.add(new InsertParam.Field(Content.Field.PARAM, paramList));
            fields.add(new InsertParam.Field(Content.Field.TYPE, typeList));
            fields.add(new InsertParam.Field(Content.Field.LABEL, labelList));

            // 3. Insert.
            milvusService.insert(Content.COLLECTION_NAME_NLP, fields);

        } catch (ApiException e) {
            System.out.println(e.getMessage());
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Converts the embedding service's double vectors into the float vectors passed to Milvus.
     *
     * @param vectors one inner list per sentence; must not be {@code null}
     * @return a new list of float vectors in the same order
     */
    private static List<List<Float>> toFloatVectors(List<List<Double>> vectors) {
        List<List<Float>> floatVectors = new ArrayList<>(vectors.size());
        for (List<Double> innerList : vectors) {
            List<Float> floatInnerList = new ArrayList<>(innerList.size());
            for (Double value : innerList) {
                floatInnerList.add(value.floatValue());
            }
            floatVectors.add(floatInnerList);
        }
        return floatVectors;
    }

    /**
     * POSTs {@code {"data":[{"text": ...}, ...]}} to the Paddle service and parses the JSON
     * response into a {@link PaddleNewTextVo}.
     *
     * @param sentenceList sentences to embed
     * @return the parsed response, or {@code null} when the service answers with a non-200 status
     *         or the body cannot be parsed
     * @throws IOException if the request body cannot be serialized or the HTTP exchange fails
     */
    private static PaddleNewTextVo getVectorsLists(List<String> sentenceList) throws IOException {
        HttpURLConnection con = (HttpURLConnection) new URL(PADDLE_URL).openConnection();
        try {
            con.setConnectTimeout(50000);
            con.setReadTimeout(200000);
            con.setRequestMethod("POST");
            con.setRequestProperty("Content-Type", "application/json");
            con.setDoOutput(true);

            Map<String, List<Map<String, String>>> dataMap = new HashMap<>();
            dataMap.put("data", sentenceList.stream()
                    .map(sentence -> Collections.singletonMap("text", sentence))
                    .collect(Collectors.toList()));
            // JsonProcessingException is an IOException, so a serialization failure propagates
            // instead of leaving a null payload that would crash on getBytes().
            String jsonData = OBJECT_MAPPER.writeValueAsString(dataMap);

            try (OutputStream os = con.getOutputStream()) {
                os.write(jsonData.getBytes(StandardCharsets.UTF_8));
            }

            int responseCode = con.getResponseCode();
            System.out.println("Response Code: " + responseCode);

            if (responseCode != HttpURLConnection.HTTP_OK) {
                System.out.println("Error Response Code: " + responseCode);
                java.io.InputStream errorStream = con.getErrorStream();
                // getErrorStream() may be null (e.g. no error body); only read when present.
                if (errorStream != null) {
                    try (BufferedReader errorReader = new BufferedReader(
                            new InputStreamReader(errorStream, StandardCharsets.UTF_8))) {
                        String errorMessage;
                        while ((errorMessage = errorReader.readLine()) != null) {
                            System.out.println("Error Message: " + errorMessage);
                        }
                    }
                }
                return null;
            }

            String contentStr = readBody(con.getInputStream());
            try {
                return JSON.parseObject(contentStr, PaddleNewTextVo.class);
            } catch (Exception e) {
                System.err.println("Error parsing JSON: " + e.getMessage());
                return null;
            }
        } finally {
            con.disconnect();
        }
    }

    /**
     * Reads a UTF-8 response stream line by line, concatenating the lines without separators,
     * and closes the stream.
     */
    private static String readBody(java.io.InputStream stream) throws IOException {
        StringBuilder content = new StringBuilder();
        try (BufferedReader in = new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8))) {
            String inputLine;
            while ((inputLine = in.readLine()) != null) {
                content.append(inputLine);
            }
        }
        return content.toString();
    }

    /** Embeds a sample query via Paddle and searches the NLP collection filtered by label. */
    @Test
    void searchTest(){
        // 0. Load the vector collection.
        milvusService.loadCollection(Content.COLLECTION_NAME_NLP);

        try {
            List<String> sentenceList = new ArrayList<>();
            sentenceList.add("XX列表");
            String label = "操作直达";

            // 1. Embed the query with the same Paddle service insertVector uses; searching with
            //    an empty vector list cannot return anything meaningful.
            PaddleNewTextVo paddleNewTextVo = getVectorsLists(sentenceList);
            Assertions.assertTrue(paddleNewTextVo != null && paddleNewTextVo.getVector() != null,
                    "Paddle embedding service returned no vectors");
            List<List<Float>> vectors = toFloatVectors(paddleNewTextVo.getVector());

            SearchNLPParamVo searchParamVo = SearchNLPParamVo.builder()
                    .collectionName(Content.COLLECTION_NAME_NLP)
                    .queryVectors(vectors)
                    .expr("label == '" + label + "'")
                    .topK(3)
                    .build();
            // 2. Search the knowledge base in the vector store.
            List<List<SearchNLPResultVo>> lists = milvusService.searchNLPTopKSimilarity(searchParamVo);
            Assertions.assertNotNull(lists, "search returned no data");
            lists.forEach(searchResultVos -> searchResultVos.forEach(searchResultVo -> {
                System.out.println(searchResultVo.getContent());
                System.out.println(searchResultVo.getContentAnswer());
                System.out.println(searchResultVo.getTitle());
                System.out.println(searchResultVo.getLabel());
            }));

        } catch (ApiException e) {
            System.out.println(e.getMessage());
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Reads {@code title||||question||||answer} lines from a text file, embeds the questions with
     * DashScope text-embedding-v1 and inserts them into the NLP collection.
     */
    @Test
    public void insertTextVector() throws IOException {
        List<String> titleList = new ArrayList<>();
        List<String> sentenceList = new ArrayList<>();
        List<String> contentAnswerList = new ArrayList<>();
        List<String> paramList = new ArrayList<>();
        List<String> typeList = new ArrayList<>();

        String filePath = "src/main/resources/data/text.txt";

        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(new FileInputStream(filePath), StandardCharsets.UTF_8))) {
            // Fields are separated by four pipes (||||).
            String line;
            while ((line = reader.readLine()) != null) {
                String[] parts = line.split("\\|\\|\\|\\|");

                if (parts.length >= 3) {
                    titleList.add(parts[0].trim());
                    sentenceList.add(parts[1].trim());
                    contentAnswerList.add(parts[2].trim());
                    paramList.add("");
                    typeList.add("2");
                } else {
                    System.out.println("Warning: Invalid format on line: " + line);
                }
            }

            System.out.println("Title List: " + titleList);
            System.out.println("Sentence List: " + sentenceList);
            System.out.println("Content Answer List: " + contentAnswerList);

        } catch (IOException e) {
            System.err.println("Error reading file: " + e.getMessage());
        }

        // Nothing parsed (missing file or no valid lines): skip the embedding call instead of
        // invoking it with an empty text list.
        if (sentenceList.isEmpty()) {
            System.out.println("Warning: no valid lines to insert from " + filePath);
            return;
        }

        try {
            // 1. Embed every question.
            // NOTE(review): insertVector embeds with the Paddle service into the same
            // CONTENT_VECTOR field; both embeddings must have the collection's dimension — confirm.
            TextEmbeddingParam param = TextEmbeddingParam
                    .builder()
                    .model(TextEmbedding.Models.TEXT_EMBEDDING_V1)
                    .texts(sentenceList).build();
            TextEmbeddingResult result = new TextEmbedding().call(param);

            List<List<Float>> vectors = new ArrayList<>();
            for (int i = 0; i < result.getOutput().getEmbeddings().size(); i++) {
                List<Double> vector = result.getOutput().getEmbeddings().get(i).getEmbedding();
                vectors.add(vector.stream()
                        .map(Double::floatValue)
                        .collect(Collectors.toList()));
            }

            // 2. Build the column-oriented insert payload.
            // NOTE(review): unlike insertVector, no Content.Field.LABEL column is supplied; if the
            // collection schema has a label field this insert will be rejected — confirm.
            List<InsertParam.Field> fields = new ArrayList<>();
            fields.add(new InsertParam.Field(Content.Field.CONTENT, sentenceList));
            fields.add(new InsertParam.Field(Content.Field.CONTENT_VECTOR, vectors));
            fields.add(new InsertParam.Field(Content.Field.CONTENT_ANSWER, contentAnswerList));
            fields.add(new InsertParam.Field(Content.Field.TITLE, titleList));
            fields.add(new InsertParam.Field(Content.Field.PARAM, paramList));
            fields.add(new InsertParam.Field(Content.Field.TYPE, typeList));
            // 3. Insert.
            milvusService.insert(Content.COLLECTION_NAME_NLP, fields);

        } catch (ApiException | NoApiKeyException e) {
            System.out.println(e.getMessage());
        }
    }

    /**
     * Retrieval-augmented chat: finds the top-3 knowledge entries for a question and asks Qwen to
     * answer using only the bracketed answers.
     */
    @Test
    void ChatBasedContentTest() throws NoApiKeyException, InputRequiredException, InterruptedException {
        // 0. Load the vector collection.
        milvusService.loadCollection(Content.COLLECTION_NAME_NLP);

        try {
            String question = "查询订单";
            List<String> sentenceList = new ArrayList<>();
            sentenceList.add(question);

            // 1. Embed the question.
            TextEmbeddingParam param = TextEmbeddingParam
                    .builder()
                    .model(TextEmbedding.Models.TEXT_EMBEDDING_V1)
                    .texts(sentenceList).build();
            TextEmbeddingResult result = new TextEmbedding().call(param);

            List<Float> floatVector = result.getOutput().getEmbeddings().get(0).getEmbedding().stream()
                    .map(Double::floatValue)
                    .collect(Collectors.toList());
            List<List<Float>> vectors = Collections.singletonList(floatVector);

            SearchERPParamVo searchParamVo = SearchERPParamVo.builder()
                    .collectionName(Content.COLLECTION_NAME_NLP)
                    .queryVectors(vectors)
                    .topK(3)
                    .build();
            // 2. Search the knowledge base in the vector store.
            List<List<SearchERPResultVo>> lists = milvusService.searchERPTopKSimilarity(searchParamVo);
            Assertions.assertNotNull(lists, "search returned no data");
            StringBuilder buffer = new StringBuilder();
            for (List<SearchERPResultVo> searchResultVos : lists) {
                for (SearchERPResultVo searchResultVo : searchResultVos) {
                    buffer.append("问题: ").append(searchResultVo.getContent()).append('\n');
                    // Wrap each answer in [] so the prompt's "return only the content inside 答案[]"
                    // instruction refers to something that actually appears in the context.
                    buffer.append("答案: [").append(searchResultVo.getContentAnswer()).append("]\n");
                }
            }
            // 3. Chat. The user question is separated from the knowledge block so it is not
            //    read as part of the last answer.
            String prompt = "请你充分理解下面的内容,然后回答问题, 要求仅返回答案[]中内容:";
            streamCallWithCallback(prompt + "\n" + buffer + "\n用户问题: " + question);

        } catch (ApiException | NoApiKeyException e) {
            System.out.println(e.getMessage());
        }
    }

    /**
     * Streams a Qwen-Plus completion for a single user message, printing each event, and blocks
     * until the stream completes or fails.
     *
     * @param content the full user prompt
     * @return the concatenated incremental output (partial if the stream errored)
     * @throws IllegalStateException if no API key is available from {@code DASHSCOPE_API_KEY}
     *         or a previously configured {@code Constants.apiKey}
     */
    public static String streamCallWithCallback(String content)
            throws NoApiKeyException, ApiException, InputRequiredException,InterruptedException {
        // The key must never be hard-coded: a key previously committed here should be treated as
        // leaked and revoked.
        String apiKey = System.getenv(DASHSCOPE_API_KEY_ENV);
        if (apiKey != null && !apiKey.trim().isEmpty()) {
            Constants.apiKey = apiKey;
        } else if (Constants.apiKey == null || Constants.apiKey.trim().isEmpty()) {
            throw new IllegalStateException("Set the " + DASHSCOPE_API_KEY_ENV + " environment variable");
        }
        Generation gen = new Generation();
        Message userMsg = Message
                .builder()
                .role(Role.USER.getValue())
                .content(content)
                .build();
        QwenParam param = QwenParam
                .builder()
                .model(Generation.Models.QWEN_PLUS)
                .resultFormat(QwenParam.ResultFormat.MESSAGE)
                .messages(Collections.singletonList(userMsg))
                .topP(0.8)
                .incrementalOutput(true) // get streaming output incrementally
                .build();
        // Released by onComplete/onError so the caller blocks until the stream finishes.
        Semaphore semaphore = new Semaphore(0);
        StringBuilder fullContent = new StringBuilder();
        gen.streamCall(param, new ResultCallback<GenerationResult>() {

            @Override
            public void onEvent(GenerationResult message) {
                fullContent.append(message.getOutput().getChoices().get(0).getMessage().getContent());
                System.out.println(message);
            }

            @Override
            public void onError(Exception err){
                System.out.println(String.format("Exception: %s", err.getMessage()));
                semaphore.release();
            }

            @Override
            public void onComplete(){
                System.out.println("Completed");
                semaphore.release();
            }

        });
        semaphore.acquire();
        System.out.println("Full content: \n" + fullContent.toString());
        return fullContent.toString();
    }

    /** Loads the first two rows of the medium-articles sample dataset and prints them. */
    @Test
    void loadData() throws IOException {
        String content = readFileToString("src/main/resources/data/medium_articles_2020_dpr.json");
        JSONObject dataset = JSON.parseObject(content);
        List<JSONObject> rows = getRows(dataset.getJSONArray("rows"), 2);
        System.out.println(rows);
    }

    /** Reads a whole file as a UTF-8 string. */
    public String readFileToString(String filePath) throws IOException {
        return new String(Files.readAllBytes(Paths.get(filePath)), StandardCharsets.UTF_8);
    }

    /**
     * Returns up to {@code counts} rows of the dataset with typed vector/number columns and the
     * {@code id} key removed.
     *
     * @param dataset the {@code rows} array of the sample dataset
     * @param counts  maximum number of rows; clamped to the dataset size
     * @return the converted rows (the same JSONObject instances, mutated in place)
     */
    public static List<JSONObject> getRows(JSONArray dataset, int counts) {
        int limit = Math.min(counts, dataset.size());
        List<JSONObject> rows = new ArrayList<>(Math.max(limit, 0));
        for (int i = 0; i < limit; i++) {
            JSONObject row = dataset.getJSONObject(i);
            List<Float> vectors = row.getJSONArray("title_vector").toJavaList(Float.class);
            Long reading_time = row.getLong("reading_time");
            Long claps = row.getLong("claps");
            Long responses = row.getLong("responses");
            row.put("title_vector", vectors);
            row.put("reading_time", reading_time);
            row.put("claps", claps);
            row.put("responses", responses);
            row.remove("id");
            rows.add(row);
        }
        return rows;
    }

    /** Builds insert fields for the first row of the sample dataset and prints them. */
    @Test
    void getFileds() throws IOException {
        String content = readFileToString("src/main/resources/data/medium_articles_2020_dpr.json");
        JSONObject dataset = JSON.parseObject(content);
        List<InsertParam.Field> field = getFields(dataset.getJSONArray("rows"), 1);
        System.out.println(field);
    }

    /**
     * Converts up to {@code counts} dataset rows into column-oriented Milvus insert fields.
     *
     * @param dataset the {@code rows} array of the sample dataset
     * @param counts  maximum number of rows; clamped to the dataset size
     * @return one {@link InsertParam.Field} per column
     */
    public static List<InsertParam.Field> getFields(JSONArray dataset, int counts) {
        int limit = Math.min(counts, dataset.size());
        List<String> titles = new ArrayList<>();
        List<List<Float>> title_vectors = new ArrayList<>();
        List<String> links = new ArrayList<>();
        List<Long> reading_times = new ArrayList<>();
        List<String> publications = new ArrayList<>();
        List<Long> claps_list = new ArrayList<>();
        List<Long> responses_list = new ArrayList<>();

        for (int i = 0; i < limit; i++) {
            JSONObject row = dataset.getJSONObject(i);
            titles.add(row.getString("title"));
            title_vectors.add(row.getJSONArray("title_vector").toJavaList(Float.class));
            links.add(row.getString("link"));
            reading_times.add(row.getLong("reading_time"));
            publications.add(row.getString("publication"));
            claps_list.add(row.getLong("claps"));
            responses_list.add(row.getLong("responses"));
        }

        List<InsertParam.Field> fields = new ArrayList<>();
        fields.add(new InsertParam.Field("title", titles));
        fields.add(new InsertParam.Field("title_vector", title_vectors));
        fields.add(new InsertParam.Field("link", links));
        fields.add(new InsertParam.Field("reading_time", reading_times));
        fields.add(new InsertParam.Field("publication", publications));
        fields.add(new InsertParam.Field("claps", claps_list));
        fields.add(new InsertParam.Field("responses", responses_list));
        return fields;
    }

    /** Searches {@code medium_articles} with the first sample row's title vector and prints hits. */
    @Test
    void searchTopKSimilarity() throws IOException {
        String content = readFileToString("src/main/resources/data/medium_articles_2020_dpr.json");
        JSONObject dataset = JSON.parseObject(content);
        List<JSONObject> rows = getRows(dataset.getJSONArray("rows"), 10);

        List<List<Float>> queryVectors = new ArrayList<>();
        List<Float> queryVector = rows.get(0).getJSONArray("title_vector").toJavaList(Float.class);
        queryVectors.add(queryVector);

        List<String> outputFields = new ArrayList<>();
        outputFields.add("title");
        outputFields.add("link");

        SearchParam searchParam = SearchParam.newBuilder()
                .withCollectionName("medium_articles")
                .withVectorFieldName("title_vector")
                .withVectors(queryVectors)
                .withExpr("claps > 30 and reading_time < 10")
                .withTopK(3)
                .withMetricType(MetricType.L2)
                .withParams("{\"nprobe\":10,\"offset\":2, \"limit\":3}")
                .withConsistencyLevel(ConsistencyLevelEnum.BOUNDED)
                .withOutFields(outputFields)
                .build();

        R<SearchResults> response = milvusClient.search(searchParam);
        Assertions.assertNotNull(response.getData(), "search failed: " + response.getMessage());

        SearchResultsWrapper wrapper = new SearchResultsWrapper(response.getData().getResults());
        // NOTE(review): assumes getRowRecords() returns the rows of all target vectors concatenated
        // in target order, so target i's rows start after the hits of targets 0..i-1 — confirm
        // against the SDK version in use.
        List<QueryResultsWrapper.RowRecord> rowRecords = wrapper.getRowRecords();
        System.out.println("Search results");
        int offset = 0;
        for (int i = 0; i < queryVectors.size(); ++i) {
            List<SearchResultsWrapper.IDScore> scores = wrapper.getIDScore(i);
            for (int j = 0; j < scores.size(); ++j) {
                SearchResultsWrapper.IDScore score = scores.get(j);
                QueryResultsWrapper.RowRecord rowRecord = rowRecords.get(offset + j);
                System.out.println("Top " + j + " ID:" + score.getLongID() + " Distance:" + score.getScore());
                System.out.println("Title: " + rowRecord.get("title"));
                System.out.println("Link: " + rowRecord.get("link"));
            }
            offset += scores.size();
        }
    }
}

4、查询：业务代码中的调用方式，以及 MivusService 中 searchNewPaddleQuestion 方法的实现

// First retrieve semantically similar corpus entries via vector search.
// Arguments: the user's message, type "1", and the caller's appType ("1" selects the app collection).
List<Question> questionList = mivusService.searchNewPaddleQuestion(req.getMessage(), "1", appType);

**
     * 根据问题进行向量查询,采用Paddle服务 采用新的文本分类方法
     * @param question 用户的问题文本
     * @return 相关的初始问题知识列表
     */
    public List<Question> searchNewPaddleQuestion(String question, String type, String appType) {

        // 0.加载向量集合
        String collection = Content.COLLECTION_NAME_NLP;
        if (appType.equals("1")) {
            collection = Content.COLLECTION_NAME_NLP_APP;
        }
        loadCollection(collection);

        List<Question> resultList = new LinkedList<>();

        PaddleNewTextVo paddleNewTextVo = null;
        try {
            List<String> sentenceList = new ArrayList<>();
            sentenceList.add(question);

            // 1.获得向量
            paddleNewTextVo = getNewNLPVectorsLists(sentenceList);
            log.info("实时向量值 : {}", paddleNewTextVo.getPredictedList());

            List<List<Double>> vectors = paddleNewTextVo.getVector();

            List<List<Float>> floatVectors = new ArrayList<>();

            for (List<Double> innerList : vectors) {
                List<Float> floatInnerList = new ArrayList<>();
                for (Double value : innerList) {
                    floatInnerList.add(value.floatValue());
                }
                floatVectors.add(floatInnerList);
            }

            List<Integer> predictedList = paddleNewTextVo.getPredictedList();

            List<String> labelStrings = new ArrayList<>();
            HashSet<Integer> setType = new HashSet();
            int topK = 3;
            if(!predictedList.isEmpty()) {
                // 去重
                for (Integer number : predictedList) {
                    setType.add(number);
                    if (number == 2) {
                        // 如何是 2
                        topK = 1;
                    }
                }
                for (Integer label : setType) {
                    labelStrings.add("'" + label + "'");
                }
            }
            String typeResult = "[" + String.join(", ", labelStrings) + "]";

            SearchNLPParamVo searchParamVo = SearchNLPParamVo.builder()
                    .collectionName(collection)
                    //.expr("type == '" + type + "'")
                    .expr("type in ['0','1','2']")
                    //.expr("type in " + typeResult + " ")
                    .queryVectors(floatVectors)
                    .topK(topK)
                    .build();
            // 2.在向量数据库中进行搜索内容知识
            List<List<SearchNLPResultVo>> lists = searchNLPERPTopKSimilarity(searchParamVo);
            lists.forEach(searchResultVos -> {
                searchResultVos.forEach(searchResultVo -> {
                    log.info(searchResultVo.getContent());
                    log.info(searchResultVo.getContentAnswer());
                    Question question1 = new Question();
                    question1.setQuestionId(Long.valueOf(searchResultVo.getId()));
                    question1.setQuestion(searchResultVo.getContent());
                    question1.setAnswer(searchResultVo.getContentAnswer());
                    question1.setTitle(searchResultVo.getTitle());
                    question1.setParam(searchResultVo.getParam());
                    question1.setType(searchResultVo.getType());
                    question1.setLabel(searchResultVo.getLabel());
                    resultList.add(question1);
                });
            });

        } catch (ApiException | IOException e) {
            log.error(e.getMessage());
        }

        // 将查询到的结果转换为之前构造的 Question 的格式返回给前端
        return resultList;
    }

上一篇:检验周转率


下一篇:JVM学习-字节码指令集(一)