Status: Open — heyufeng666888 opened this issue 1 year ago.
` package com.fosafer.smart.utils;
import cn.hutool.core.util.NumberUtil; import com.google.common.util.concurrent.ListenableFuture; import io.milvus.client.MilvusClient; import io.milvus.client.MilvusServiceClient; import io.milvus.common.clientenum.ConsistencyLevelEnum; import io.milvus.grpc.; import io.milvus.param.; import io.milvus.param.collection.; import io.milvus.param.control.ManualCompactParam; import io.milvus.param.dml.DeleteParam; import io.milvus.param.dml.InsertParam; import io.milvus.param.dml.SearchParam; import io.milvus.param.index.; import io.milvus.param.partition.*; import io.milvus.response.DescCollResponseWrapper; import io.milvus.response.GetCollStatResponseWrapper; import io.milvus.response.SearchResultsWrapper; import lombok.extern.slf4j.Slf4j;
import java.math.RoundingMode; import java.util.ArrayList; import java.util.List; import java.util.concurrent.TimeUnit;
/**
 * Static helper around the Milvus Java SDK client for managing collections, partitions,
 * indexes and vector data.
 *
 * @author 贺宇峰
 */
@Slf4j
public class SearchUtil {

/**
 * Shared Milvus client; stays {@code null} until {@link #init(String, Integer)} has run.
 * {@code volatile} so the client published by {@code init} is visible to every thread.
 */
public static volatile MilvusClient milvusClient = null;

/** Collection name. */
public static final String COLLECTION_NAME = "surveillance";

/** Primary-key field name. */
private static final String ID_FIELD = "id";

/** Vector (model) field name. */
private static final String VECTOR_FIELD = "model";

/** Vector dimension. */
private static final Integer VECTOR_DIM = 512;

/** Vector index name. */
public static final String INDEX_NAME = "modelindex";

/** Default partition name. */
public static final String DEFAULT_PARTITION = "_default";

/**
 * Creates the shared Milvus client on the first call; later calls are no-ops and their
 * {@code host}/{@code port} are ignored.
 * Synchronized so two concurrent callers cannot both observe {@code null} and each build
 * (and leak) a client.
 *
 * @param host Milvus server host
 * @param port Milvus server port
 */
public static synchronized void init(String host, Integer port) {
    if (milvusClient == null) {
        ConnectParam connectParam = ConnectParam.newBuilder()
                .withHost(host)
                .withPort(port)
                .build();
        milvusClient = new MilvusServiceClient(connectParam);
    }
}
/**
/**
/**
/**
private static void handleResponseStatus(R<?> r) { if (r.getStatus() != R.Status.Success.getCode()) { throw new RuntimeException(r.getMessage()); } }
/**
@return grpc返回值
*/
public static R
FieldType vectorField = FieldType.newBuilder() .withName(VECTOR_FIELD) .withDescription("声纹模型") .withDataType(DataType.FloatVector) .withDimension(VECTOR_DIM) .build();
CreateCollectionParam createCollectionReq = CreateCollectionParam.newBuilder()
.withCollectionName(collectionName)
.withDescription(description)
.withShardsNum(2)
.addFieldType(idField)
.addFieldType(vectorField)
.build();
R
/**
/**
/**
/**
/**
/**
/**
/**
/**
/**
/**
/**
/**
/**
/**
/**
/**
/**
/**
/**
@return 检索返回值
*/
public static SearchResultsWrapper search(String collectionName, MetricType metricType, int topK, List<List
R
/**
/**
@return 返回值
*/
public static R
InsertParam insertParam = InsertParam.newBuilder()
.withCollectionName(collectionName)
.withPartitionName(partitionName)
.withFields(fields)
.build();
ListenableFuture<R
String tableName = RepositoryConstant.getBKProfileTableName(String.valueOf(repository.getId()));
// Create the collection only when it is missing, as intended, instead of depending on how
// Milvus answers a duplicate create: any non-success status is thrown by SearchUtil and
// would abort the whole load.
if (!SearchUtil.hasCollection(tableName)) {
    SearchUtil.createCollection(tableName, repository.getRepositoryId() + "_" + repository.getRepositoryName(), baseProperties.getSearchTimeout());
}
SearchUtil.createIndex(tableName, IndexType.valueOf(baseProperties.getSearchIndexType()), MetricType.valueOf(baseProperties.getSearchMetricType()),
        SearchUtil.getIndexJsonParam(baseProperties.getSearchIndexType(), baseProperties.getIndexParam1(), baseProperties.getIndexParam2()));
SearchUtil.loadCollection(tableName);
int count = 0;
// Read the repository's valid profiles page by page and push them to Milvus in batches.
List<BkProfileXxx> profileXxxes = profileXxxMapper.selectBkProfileXxx(tableName, RepositoryConstant.STATUS_VALID, 1, endTime, 0, baseProperties.getSearchPageSize());
while (CollectionUtil.isNotEmpty(profileXxxes)) {
    List<List<BkProfileXxx>> profiles = CollectionUtil.split(profileXxxes, baseProperties.getSearchInsertSize());
    profiles.forEach(bkProfileXxxes -> {
        List<List<Float>> vectors = new ArrayList<>();
        List<String> ids = new ArrayList<>();
        for (BkProfileXxx profileXxx : bkProfileXxxes) {
            ids.add(profileXxx.getUserid());
            vectors.add(DataCalcUtil.byteArrayToFloatList(profileXxx.getModel()));
        }
        // insertBatch logs and returns null on failure rather than throwing; make the lost batch visible.
        if (SearchUtil.insertBatch(tableName, SearchUtil.DEFAULT_PARTITION, ids, vectors) == null) {
            log.warn("loadHistoryModel insertBatch failed tableName {} size {}", tableName, ids.size());
        }
    });
    // A page shorter than pageSize is the last one.
    if (profileXxxes.size() < baseProperties.getSearchPageSize()) {
        break;
    }
    ++count;
    profileXxxes = profileXxxMapper.selectBkProfileXxx(tableName, RepositoryConstant.STATUS_VALID, 1, endTime, count * baseProperties.getSearchPageSize(), baseProperties.getSearchPageSize());
}
} catch (Exception e) {
    log.error("loadHistoryModel error", e);
}
`
` package com.fosafer.smart.utils;
import cn.hutool.core.util.NumberUtil; import com.google.common.util.concurrent.ListenableFuture; import io.milvus.client.MilvusClient; import io.milvus.client.MilvusServiceClient; import io.milvus.common.clientenum.ConsistencyLevelEnum; import io.milvus.grpc.; import io.milvus.param.; import io.milvus.param.collection.; import io.milvus.param.control.ManualCompactParam; import io.milvus.param.dml.DeleteParam; import io.milvus.param.dml.InsertParam; import io.milvus.param.dml.SearchParam; import io.milvus.param.index.; import io.milvus.param.partition.*; import io.milvus.response.DescCollResponseWrapper; import io.milvus.response.GetCollStatResponseWrapper; import io.milvus.response.SearchResultsWrapper; import lombok.extern.slf4j.Slf4j;
import java.math.RoundingMode; import java.util.ArrayList; import java.util.List; import java.util.concurrent.TimeUnit;
/**
 * Static helper around the Milvus Java SDK client for managing collections, partitions,
 * indexes and vector data.
 *
 * @author 贺宇峰
 */
@Slf4j
public class SearchUtil {

/**
 * Shared Milvus client; stays {@code null} until {@link #init(String, Integer)} has run.
 * {@code volatile} so the client published by {@code init} is visible to every thread.
 */
public static volatile MilvusClient milvusClient = null;

/**
 * Creates the shared Milvus client on the first call; later calls are no-ops and their
 * {@code host}/{@code port} are ignored.
 * Synchronized so two concurrent callers cannot both observe {@code null} and each build
 * (and leak) a client.
 *
 * @param host Milvus server host
 * @param port Milvus server port
 */
public static synchronized void init(String host, Integer port) {
    if (milvusClient == null) {
        ConnectParam connectParam = ConnectParam.newBuilder()
                .withHost(host)
                .withPort(port)
                .build();
        milvusClient = new MilvusServiceClient(connectParam);
    }
}

/** Collection name. */
public static final String COLLECTION_NAME = "surveillance";

/** Primary-key field name. */
private static final String ID_FIELD = "id";

/** Vector (model) field name. */
private static final String VECTOR_FIELD = "model";

/** Vector dimension. */
private static final Integer VECTOR_DIM = 512;

/** Vector index name. */
public static final String INDEX_NAME = "modelindex";

/** Default partition name. */
public static final String DEFAULT_PARTITION = "_default";
/**
获取索引检索json参数
@param indexType 索引类型
@param param 参数
@return 公制json参数 */ public static String getSearchJsonParam(String indexType, Integer... param) { if (IndexType.FLAT.name().equals(indexType) || IndexType.IVF_FLAT.name().equals(indexType) || IndexType.IVF_SQ8.name().equals(indexType)) { return "{"nprobe":" + param[0] + "}"; }else if (IndexType.HNSW.name().equals(indexType)) { return "{"ef":" + param[0] + "}"; } return ""; }
/**
获取索引json参数
@param indexType 索引类型
@param param 参数
@return 索引json参数 */ public static String getIndexJsonParam(String indexType, Integer... param) { if (IndexType.FLAT.name().equals(indexType) || IndexType.IVF_FLAT.name().equals(indexType) || IndexType.IVF_SQ8.name().equals(indexType)) { return "{"nlist":" + param[0] + "}"; } else if (IndexType.HNSW.name().equals(indexType)) { return "{"M": " + param[0] + ", "efConstruction": " + param[1] + "}"; } return ""; }
/**
获取分数
@param metricType 公制类型
@param score 分数
@return 分数 / public static Float getScore(String metricType, Float score) { if (MetricType.IP.name().equals(metricType)) { return NumberUtil.round(String.valueOf((score + 1f) 0.5f), 2, RoundingMode.HALF_UP).floatValue(); } else if (MetricType.L2.name().equals(metricType)) { return NumberUtil.round(String.valueOf(1f - score / 4f), 2, RoundingMode.HALF_UP).floatValue(); } return 0f; }
/**
 * Throws if a Milvus SDK call did not report success.
 *
 * @param r the SDK response to check
 * @throws RuntimeException carrying the SDK message, with the SDK exception (if any) as cause
 */
private static void handleResponseStatus(R<?> r) {
    if (r.getStatus() != R.Status.Success.getCode()) {
        // Keep the SDK's exception as the cause so the original stack trace is not lost.
        throw new RuntimeException(r.getMessage(), r.getException());
    }
}

/**
 * Creates a collection with a VarChar primary key and a float-vector field.
 *
 * @param collectionName      collection name
 * @param description         collection description
 * @param timeoutMilliseconds timeout for the create call, in milliseconds
 * @return the gRPC response
 * @throws RuntimeException if Milvus reports a non-success status
 */
public static R<RpcStatus> createCollection(String collectionName, String description, long timeoutMilliseconds) {
    log.info("SearchUtil.createCollection collectionName {} description {}", collectionName, description);
    FieldType idField = FieldType.newBuilder()
            .withName(ID_FIELD)
            .withDescription("主键id")
            .withDataType(DataType.VarChar)
            .withMaxLength(32)
            .withPrimaryKey(true)
            .build();
    FieldType vectorField = FieldType.newBuilder()
            .withName(VECTOR_FIELD)
            .withDescription("声纹模型")
            .withDataType(DataType.FloatVector)
            .withDimension(VECTOR_DIM)
            .build();
    CreateCollectionParam createCollectionReq = CreateCollectionParam.newBuilder()
            .withCollectionName(collectionName)
            .withDescription(description)
            .withShardsNum(2)
            .addFieldType(idField)
            .addFieldType(vectorField)
            .build();
    R<RpcStatus> response = milvusClient.withTimeout(timeoutMilliseconds, TimeUnit.MILLISECONDS)
            .createCollection(createCollectionReq);
    handleResponseStatus(response);
    log.info("SearchUtil.createCollection {}", response);
    return response;
}
/**
 * Drops a collection.
 * Unlike most methods here it does not throw on a non-success status; a warning is logged
 * and the response is returned for the caller to inspect.
 *
 * @param collectionName collection name
 * @return the gRPC response
 */
public static R<RpcStatus> dropCollection(String collectionName) {
    log.info("SearchUtil.dropCollection collectionName {}", collectionName);
    R<RpcStatus> response = milvusClient.dropCollection(DropCollectionParam.newBuilder()
            .withCollectionName(collectionName)
            .build());
    if (response.getStatus() != R.Status.Success.getCode()) {
        log.warn("SearchUtil.dropCollection failed collectionName {} message {}", collectionName, response.getMessage());
    }
    log.info("SearchUtil.dropCollection {}", response);
    return response;
}
/**
 * Checks whether a collection exists.
 *
 * @param collectionName collection name
 * @return {@code true} if the collection exists
 * @throws RuntimeException if Milvus reports a non-success status
 */
public static boolean hasCollection(String collectionName) {
    log.info("SearchUtil.hasCollection collectionName {}", collectionName);
    R<Boolean> response = milvusClient.hasCollection(HasCollectionParam.newBuilder()
            .withCollectionName(collectionName)
            .build());
    handleResponseStatus(response);
    log.info("SearchUtil.hasCollection {}", response);
    return response.getData();
}
/**
 * Loads a collection so it can be searched.
 *
 * @param collectionName collection name
 * @return the gRPC response
 * @throws RuntimeException if Milvus reports a non-success status
 */
public static R<RpcStatus> loadCollection(String collectionName) {
    log.info("SearchUtil.loadCollection collectionName {}", collectionName);
    R<RpcStatus> response = milvusClient.loadCollection(LoadCollectionParam.newBuilder()
            .withCollectionName(collectionName)
            .build());
    handleResponseStatus(response);
    log.info("SearchUtil.loadCollection {}", response);
    return response;
}
/**
 * Releases a loaded collection.
 *
 * @param collectionName collection name
 * @return the gRPC response
 * @throws RuntimeException if Milvus reports a non-success status
 */
public static R<RpcStatus> releaseCollection(String collectionName) {
    log.info("SearchUtil.releaseCollection collectionName {}", collectionName);
    R<RpcStatus> response = milvusClient.releaseCollection(ReleaseCollectionParam.newBuilder()
            .withCollectionName(collectionName)
            .build());
    handleResponseStatus(response);
    log.info("SearchUtil.releaseCollection {}", response);
    return response;
}
/**
 * Gets the definition of a collection.
 *
 * @param collectionName collection name
 * @return the collection definition response
 * @throws RuntimeException if Milvus reports a non-success status
 */
public static R<DescribeCollectionResponse> describeCollection(String collectionName) {
    log.info("SearchUtil.describeCollection collectionName {}", collectionName);
    // Describe the requested collection (previously the COLLECTION_NAME constant was used by mistake).
    R<DescribeCollectionResponse> response = milvusClient.describeCollection(DescribeCollectionParam.newBuilder()
            .withCollectionName(collectionName)
            .build());
    handleResponseStatus(response);
    DescCollResponseWrapper wrapper = new DescCollResponseWrapper(response.getData());
    log.info("SearchUtil.describeCollection {}", wrapper);
    return response;
}
/**
获取库统计
@param collectionName 库名
@return 库统计 */ public static R getCollectionStatistics(String collectionName) { // call flush() to flush the insert buffer to storage, // so that the getCollectionStatistics() can get correct number log.info("SearchUtil.getCollectionStatistics collectionName {}", collectionName); milvusClient.flush(FlushParam.newBuilder().addCollectionName(collectionName).build()); R response = milvusClient.getCollectionStatistics( GetCollectionStatisticsParam.newBuilder() .withCollectionName(collectionName) .build()); handleResponseStatus(response); GetCollStatResponseWrapper wrapper = new GetCollStatResponseWrapper(response.getData()); log.info("SearchUtil.getCollectionStatistics {}", wrapper); return response; }
/**
 * Lists the collections.
 *
 * @return the collection list response
 * @throws IllegalStateException if {@link #init(String, Integer)} has not created the client
 * @throws RuntimeException      if Milvus reports a non-success status
 */
public static R<ShowCollectionsResponse> showCollections() {
    log.info("SearchUtil.showCollections");
    if (milvusClient == null) {
        throw new IllegalStateException("milvusClient init error");
    }
    R<ShowCollectionsResponse> response = milvusClient.showCollections(ShowCollectionsParam.newBuilder()
            .build());
    handleResponseStatus(response);
    log.info("SearchUtil.showCollections {}", response);
    return response;
}
/**
 * Creates a partition in a collection.
 *
 * @param collectionName collection name
 * @param partitionName  partition name
 * @return the gRPC response
 * @throws RuntimeException if Milvus reports a non-success status
 */
public static R<RpcStatus> createPartition(String collectionName, String partitionName) {
    log.info("SearchUtil.createPartition collectionName {} partitionName {}", collectionName, partitionName);
    R<RpcStatus> response = milvusClient.createPartition(CreatePartitionParam.newBuilder()
            .withCollectionName(collectionName)
            .withPartitionName(partitionName)
            .build());
    handleResponseStatus(response);
    log.info("SearchUtil.createPartition {}", response);
    return response;
}
/**
 * Drops a partition from a collection.
 *
 * @param collectionName collection name
 * @param partitionName  partition name
 * @return the gRPC response
 * @throws RuntimeException if Milvus reports a non-success status
 */
public static R<RpcStatus> dropPartition(String collectionName, String partitionName) {
    log.info("SearchUtil.dropPartition collectionName {} partitionName {}", collectionName, partitionName);
    R<RpcStatus> response = milvusClient.dropPartition(DropPartitionParam.newBuilder()
            .withCollectionName(collectionName)
            .withPartitionName(partitionName)
            .build());
    handleResponseStatus(response);
    log.info("SearchUtil.dropPartition {}", response);
    return response;
}
/**
 * Checks whether a partition exists.
 *
 * @param collectionName collection name
 * @param partitionName  partition name
 * @return the response whose data is {@code true} if the partition exists
 * @throws RuntimeException if Milvus reports a non-success status
 */
public static R<Boolean> hasPartition(String collectionName, String partitionName) {
    log.info("SearchUtil.hasPartition collectionName {} partitionName {}", collectionName, partitionName);
    R<Boolean> response = milvusClient.hasPartition(HasPartitionParam.newBuilder()
            .withCollectionName(collectionName)
            .withPartitionName(partitionName)
            .build());
    handleResponseStatus(response);
    log.info("SearchUtil.hasPartition {}", response);
    return response;
}
/**
 * Loads a single partition so it can be searched.
 *
 * @param collectionName collection name
 * @param partitionName  partition name
 * @return the gRPC response
 * @throws RuntimeException if Milvus reports a non-success status
 */
public static R<RpcStatus> loadPartition(String collectionName, String partitionName) {
    log.info("SearchUtil.loadPartition collectionName {} partitionName {}", collectionName, partitionName);
    R<RpcStatus> response = milvusClient.loadPartitions(LoadPartitionsParam.newBuilder()
            .withCollectionName(collectionName)
            .addPartitionName(partitionName)
            .build());
    handleResponseStatus(response);
    log.info("SearchUtil.loadPartition {}", response);
    return response;
}
/**
 * Releases a single loaded partition.
 *
 * @param collectionName collection name
 * @param partitionName  partition name
 * @return the gRPC response
 * @throws RuntimeException if Milvus reports a non-success status
 */
public static R<RpcStatus> releasePartition(String collectionName, String partitionName) {
    log.info("SearchUtil.releasePartition collectionName {} partitionName {}", collectionName, partitionName);
    R<RpcStatus> response = milvusClient.releasePartitions(ReleasePartitionsParam.newBuilder()
            .withCollectionName(collectionName)
            .addPartitionName(partitionName)
            .build());
    handleResponseStatus(response);
    log.info("SearchUtil.releasePartition {}", response);
    return response;
}
/**
 * Lists the partitions of a collection.
 *
 * @param collectionName collection name
 * @return the partition list response
 * @throws RuntimeException if Milvus reports a non-success status
 */
public static R<ShowPartitionsResponse> showPartitions(String collectionName) {
    log.info("SearchUtil.showPartitions collectionName {}", collectionName);
    R<ShowPartitionsResponse> response = milvusClient.showPartitions(ShowPartitionsParam.newBuilder()
            .withCollectionName(collectionName)
            .build());
    handleResponseStatus(response);
    log.info("SearchUtil.showPartitions {}", response);
    return response;
}
/**
 * Creates the vector index on the vector field.
 *
 * @param collectionName collection name
 * @param indexType      index type
 * @param metricType     metric type
 * @param indexJsonParam index JSON parameters (see getIndexJsonParam)
 * @return the gRPC response
 * @throws RuntimeException if Milvus reports a non-success status
 */
public static R<RpcStatus> createIndex(String collectionName, IndexType indexType, MetricType metricType, String indexJsonParam) {
    log.info("SearchUtil.createIndex collectionName {} indexType {} metricType {} indexJsonParam {}", collectionName, indexType.name(), metricType.name(), indexJsonParam);
    R<RpcStatus> response = milvusClient.createIndex(CreateIndexParam.newBuilder()
            .withCollectionName(collectionName)
            .withFieldName(VECTOR_FIELD)
            .withIndexName(INDEX_NAME)
            .withIndexType(indexType)
            .withMetricType(metricType)
            .withExtraParam(indexJsonParam)
            // Sync mode off: the call does not wait for the build (poll getIndexBuildProgress).
            .withSyncMode(Boolean.FALSE)
            .build());
    handleResponseStatus(response);
    log.info("SearchUtil.createIndex {}", response);
    return response;
}
/**
 * Drops the vector index.
 *
 * @param collectionName collection name
 * @return the gRPC response
 * @throws RuntimeException if Milvus reports a non-success status
 */
public static R<RpcStatus> dropIndex(String collectionName) {
    log.info("SearchUtil.dropIndex collectionName {}", collectionName);
    R<RpcStatus> response = milvusClient.dropIndex(DropIndexParam.newBuilder()
            .withCollectionName(collectionName)
            .withIndexName(INDEX_NAME)
            .build());
    handleResponseStatus(response);
    log.info("SearchUtil.dropIndex {}", response);
    return response;
}
/**
 * Gets the definition of the vector index.
 *
 * @param collectionName collection name
 * @return the index definition response
 * @throws RuntimeException if Milvus reports a non-success status
 */
public static R<DescribeIndexResponse> describeIndex(String collectionName) {
    log.info("SearchUtil.describeIndex collectionName {}", collectionName);
    R<DescribeIndexResponse> response = milvusClient.describeIndex(DescribeIndexParam.newBuilder()
            .withCollectionName(collectionName)
            .withIndexName(INDEX_NAME)
            .build());
    handleResponseStatus(response);
    log.info("SearchUtil.describeIndex {}", response);
    return response;
}
/**
 * Gets the state of the vector index.
 *
 * @param collectionName collection name
 * @return the index state response
 * @throws RuntimeException if Milvus reports a non-success status
 */
public static R<GetIndexStateResponse> getIndexState(String collectionName) {
    log.info("SearchUtil.getIndexState collectionName {}", collectionName);
    R<GetIndexStateResponse> response = milvusClient.getIndexState(GetIndexStateParam.newBuilder()
            .withCollectionName(collectionName)
            .withIndexName(INDEX_NAME)
            .build());
    handleResponseStatus(response);
    log.info("SearchUtil.getIndexState {}", response);
    return response;
}
/**
 * Gets the build progress of the vector index.
 *
 * @param collectionName collection name
 * @return the index build progress response
 * @throws RuntimeException if Milvus reports a non-success status
 */
public static R<GetIndexBuildProgressResponse> getIndexBuildProgress(String collectionName) {
    log.info("SearchUtil.getIndexBuildProgress collectionName {}", collectionName);
    R<GetIndexBuildProgressResponse> response = milvusClient.getIndexBuildProgress(
            GetIndexBuildProgressParam.newBuilder()
                    .withCollectionName(collectionName)
                    .withIndexName(INDEX_NAME)
                    .build());
    handleResponseStatus(response);
    log.info("SearchUtil.getIndexBuildProgress {}", response);
    return response;
}
/**
 * Deletes the entities matching a boolean expression.
 * Best-effort: any failure is logged and reported as a {@code null} return.
 *
 * @param collectionName collection name
 * @param partitionName  partition name
 * @param expr           Milvus boolean expression selecting the entities to delete
 * @return the delete response, or {@code null} if the delete failed
 */
public static R<MutationResult> delete(String collectionName, String partitionName, String expr) {
    log.info("SearchUtil.delete collectionName {} partitionName {} expr {}", collectionName, partitionName, expr);
    DeleteParam build = DeleteParam.newBuilder()
            .withCollectionName(collectionName)
            .withPartitionName(partitionName)
            .withExpr(expr)
            .build();
    try {
        R<MutationResult> response = milvusClient.delete(build);
        handleResponseStatus(response);
        log.info("SearchUtil.delete {}", response);
        return response;
    } catch (Exception e) {
        log.error("delete error", e);
    }
    return null;
}
/**
 * Runs a top-K vector search against the vector field of a collection.
 *
 * @param collectionName  collection name
 * @param metricType      metric type
 * @param topK            number of nearest results per query vector
 * @param vectors         query vectors
 * @param searchJsonParam search JSON parameters (see getSearchJsonParam)
 * @return a wrapper over the search results
 * @throws RuntimeException if Milvus reports a non-success status
 */
public static SearchResultsWrapper search(String collectionName, MetricType metricType, int topK, List<List<Float>> vectors, String searchJsonParam) {
    log.debug("SearchUtil.search collectionName {} metricType {} searchJsonParam {} topK {} vectors.size {}", collectionName, metricType.name(), searchJsonParam, topK, vectors.size());
    long begin = System.currentTimeMillis();
    SearchParam searchParam = SearchParam.newBuilder()
            .withCollectionName(collectionName)
            .withMetricType(metricType)
            .withTopK(topK)
            .withVectors(vectors)
            // Growing segments are skipped; NOTE(review): recently inserted vectors may
            // not be searchable until their segment is sealed — confirm this is acceptable.
            .withIgnoreGrowing(true)
            .withVectorFieldName(VECTOR_FIELD)
            .withParams(searchJsonParam)
            .withConsistencyLevel(ConsistencyLevelEnum.EVENTUALLY)
            .withGuaranteeTimestamp(Constant.GUARANTEE_EVENTUALLY_TS)
            .build();
    R<SearchResults> response = milvusClient.search(searchParam);
    log.debug("SearchUtil.search time cost {} ms", System.currentTimeMillis() - begin);
    handleResponseStatus(response);
    log.debug("SearchUtil.search {}", response);
    return new SearchResultsWrapper(response.getData().getResults());
}
/**
 * Triggers a manual compaction of a collection.
 *
 * @param collectionName collection name
 * @return the compaction response
 * @throws RuntimeException if Milvus reports a non-success status
 */
public static R<ManualCompactionResponse> compact(String collectionName) {
    log.info("SearchUtil.compact collectionName {}", collectionName);
    R<ManualCompactionResponse> response = milvusClient.manualCompact(ManualCompactParam.newBuilder()
            .withCollectionName(collectionName)
            .build());
    handleResponseStatus(response);
    log.info("SearchUtil.compact {}", response);
    return response;
}
/**
 * Inserts a batch of (id, vector) rows into a partition.
 * Failures are logged and reported as a {@code null} return instead of an exception.
 *
 * @param collectionName collection name
 * @param partitionName  partition name
 * @param ids            primary keys
 * @param vectors        vectors, positionally matching {@code ids}
 * @return the insert response, or {@code null} if the insert failed
 */
public static R<MutationResult> insertBatch(String collectionName, String partitionName, List<String> ids, List<List<Float>> vectors) {
    log.debug("SearchUtil.insertBatch collectionName {} partitionName {} ids.size {} vectors.size {}", collectionName, partitionName, ids.size(), vectors.size());
    List<InsertParam.Field> fields = new ArrayList<>();
    fields.add(new InsertParam.Field(ID_FIELD, ids));
    fields.add(new InsertParam.Field(VECTOR_FIELD, vectors));
    InsertParam insertParam = InsertParam.newBuilder()
            .withCollectionName(collectionName)
            .withPartitionName(partitionName)
            .withFields(fields)
            .build();
    ListenableFuture<R<MutationResult>> future = milvusClient.insertAsync(insertParam);
    try {
        R<MutationResult> response = future.get();
        handleResponseStatus(response);
        return response;
    } catch (InterruptedException e) {
        // Restore the interrupt flag so the caller can still observe the interruption.
        Thread.currentThread().interrupt();
        log.error("insertBatch error", e);
    } catch (Exception e) {
        log.error("insertBatch error", e);
    }
    return null;
}
}

String tableName = RepositoryConstant.getBKProfileTableName(String.valueOf(repository.getId()));
// Create the collection only when it is missing, as intended, instead of depending on how
// Milvus answers a duplicate create: any non-success status is thrown by SearchUtil and
// would abort the whole load.
if (!SearchUtil.hasCollection(tableName)) {
    SearchUtil.createCollection(tableName, repository.getRepositoryId() + "_" + repository.getRepositoryName(), baseProperties.getSearchTimeout());
}
SearchUtil.createIndex(tableName, IndexType.valueOf(baseProperties.getSearchIndexType()), MetricType.valueOf(baseProperties.getSearchMetricType()),
        SearchUtil.getIndexJsonParam(baseProperties.getSearchIndexType(), baseProperties.getIndexParam1(), baseProperties.getIndexParam2()));
SearchUtil.loadCollection(tableName);
int count = 0;
// Read the repository's valid profiles page by page and push them to Milvus in batches.
List<BkProfileXxx> profileXxxes = profileXxxMapper.selectBkProfileXxx(tableName, RepositoryConstant.STATUS_VALID, 1, endTime, 0, baseProperties.getSearchPageSize());
while (CollectionUtil.isNotEmpty(profileXxxes)) {
    List<List<BkProfileXxx>> profiles = CollectionUtil.split(profileXxxes, baseProperties.getSearchInsertSize());
    profiles.forEach(bkProfileXxxes -> {
        List<List<Float>> vectors = new ArrayList<>();
        List<String> ids = new ArrayList<>();
        for (BkProfileXxx profileXxx : bkProfileXxxes) {
            ids.add(profileXxx.getUserid());
            vectors.add(DataCalcUtil.byteArrayToFloatList(profileXxx.getModel()));
        }
        // insertBatch logs and returns null on failure rather than throwing; make the lost batch visible.
        if (SearchUtil.insertBatch(tableName, SearchUtil.DEFAULT_PARTITION, ids, vectors) == null) {
            log.warn("loadHistoryModel insertBatch failed tableName {} size {}", tableName, ids.size());
        }
    });
    // A page shorter than pageSize is the last one.
    if (profileXxxes.size() < baseProperties.getSearchPageSize()) {
        break;
    }
    ++count;
    // Fixed method-name typo (was selectBkPrdofileXxx).
    profileXxxes = profileXxxMapper.selectBkProfileXxx(tableName, RepositoryConstant.STATUS_VALID, 1, endTime, count * baseProperties.getSearchPageSize(), baseProperties.getSearchPageSize());
}
} catch (Exception e) {
    log.error("loadHistoryModel error", e);
}
`
Hi @heyufeng666888, yes, you are right: the larger M is, the more memory the index consumes and the longer the index takes to build. efConstruction should have some impact on the index build time too.
If you can share some of your benchmark results, that would be great.
Also, you don't really need to do the pagination yourself; Milvus already supports offset and limit.
The larger the HNSW M, the longer the index construction time, while efConstruction has no impact. Please check.