FlagOpen / FlagEmbedding

Retrieval and Retrieval-augmented LLMs
MIT License
7.38k stars 533 forks source link

bge-reranker-v2-m3 模型 GPU 利用率高、显存占用高，并且很久也没有输出结果 #918

Open xiaoToby opened 4 months ago

xiaoToby commented 4 months ago

我使用以下脚本调用了 bge-reranker-v2-m3 模型：

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time: 2023/11/7 22:45
@Author: zhidong
@File: reranker.py
@Desc:
"""
import datetime
import logging
import os
import secrets
import threading
from typing import Optional, List

import numpy as np
import uvicorn
from fastapi import FastAPI, Security, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from FlagEmbedding import FlagReranker
from pydantic import Field, BaseModel, validator

app = FastAPI()
security = HTTPBearer()
# Bearer token that clients must present. It is read from the ACCESS_TOKEN
# environment variable at import time. That way it is also honoured when the
# app is served through the uvicorn CLI, not only via `python reranker.py`.
# The literal fallback 'ACCESS_TOKEN' is kept for backward compatibility.
# It is guessable, so real deployments should always set the variable.
env_bearer_token = os.getenv("ACCESS_TOKEN", 'ACCESS_TOKEN')

class QADocs(BaseModel):
    """Request body for POST /v1/rerank: one query plus the documents to rank against it.

    NOTE(review): whether these Optional fields are required depends on the
    pydantic major version (v1 gives Optional fields a None default, v2 does
    not) — confirm which version is deployed. Fields not declared here (e.g.
    "model", "top_n") are not part of this model.
    """

    # Query text that each document is paired with for scoring.
    query: Optional[str]
    # Candidate documents; the "index" in each response item is a position in this list.
    documents: Optional[List[str]]

class Singleton(type):
    """Metaclass that lets each class using it construct at most one instance.

    The first ``Cls(*args, **kwargs)`` call builds and caches the instance.
    Later calls return the cached object and ignore their arguments.
    Each subclass gets its own instance.
    """

    def __call__(cls, *args, **kwargs):
        # Look only in this class's own namespace. ``hasattr`` would also find a
        # parent class's cached ``_instance`` and return it for a subclass.
        if '_instance' not in cls.__dict__:
            cls._instance = super().__call__(*args, **kwargs)
        return cls._instance

# Location of the bge-reranker-v2-m3 weights. By default this is a folder next
# to this file; set the RERANK_MODEL_PATH environment variable to load them
# from elsewhere without editing the code.
RERANK_MODEL_PATH = os.getenv(
    "RERANK_MODEL_PATH",
    os.path.join(os.path.dirname(__file__), "bge-reranker-v2-m3"),
)

class ReRanker(metaclass=Singleton):
    """Process-wide wrapper around FlagEmbedding's ``FlagReranker``.

    Because of the ``Singleton`` metaclass, only the first construction loads
    the model. Later ``ReRanker(...)`` calls return that same instance and
    ignore their arguments.
    """

    def __init__(self, model_path, use_fp16=False):
        """Load the reranker model.

        Args:
            model_path: passed to ``FlagReranker`` as the model to load.
            use_fp16: forwarded to ``FlagReranker``. Defaults to False, the
                value that used to be hard-coded. NOTE(review): True
                presumably lowers GPU memory use at some accuracy cost —
                confirm for this model.
        """
        self.reranker = FlagReranker(model_path, use_fp16=use_fp16)

    def compute_score(self, pairs: List[List[str]]):
        """Score ``[query, document]`` pairs with normalisation enabled.

        Returns:
            A list with one score per pair, or None when ``pairs`` is empty.
        """
        if len(pairs) == 0:
            return None
        result = self.reranker.compute_score(pairs, normalize=True)
        # FlagReranker hands back a bare scalar for a single pair (the original
        # code special-cased ``float``). np.atleast_1d(...).tolist() turns a
        # scalar, list or array alike into a flat list of Python floats.
        return np.atleast_1d(result).tolist()

class Chat(object):
    """Request-level helper that reranks a ``QADocs`` payload with the shared ``ReRanker``."""

    def __init__(self, rerank_model_path: str = RERANK_MODEL_PATH):
        # ReRanker uses the Singleton metaclass, so only the first Chat() loads
        # the model. Later constructions reuse it and ignore rerank_model_path.
        self.reranker = ReRanker(rerank_model_path)

    def fit_query_answer_rerank(self, query_docs: QADocs) -> List:
        """Rank ``query_docs.documents`` by relevance to ``query_docs.query``.

        Returns:
            A list of ``{"index": i, "relevance_score": s}`` dicts, sorted by
            descending score. ``i`` is the document's position in the request,
            and equal scores keep request order. Returns an empty list when
            there is no payload or no documents, including ``documents=None``.

        NOTE(review): extra request fields such as ``top_n`` (sent by the test
        client) are not declared on QADocs and are not applied here.
        """
        if query_docs is None or not query_docs.documents:
            return []

        pairs = [[query_docs.query, doc] for doc in query_docs.documents]
        scores = self.reranker.compute_score(pairs)

        # sorted() is stable, so documents with equal scores keep their request order.
        ranked = sorted(enumerate(scores), key=lambda item: item[1], reverse=True)
        return [{"index": index, "relevance_score": score} for index, score in ranked]

# Serialises model work done in threadpool workers. The model behind Chat is
# shared by every request (ReRanker is a singleton). NOTE(review): the
# tokenizer/model pair is presumably not safe to drive from several threads at
# once — relax this lock only after confirming otherwise.
_rerank_lock = threading.Lock()


def _call_serialized(fn, *args):
    """Run ``fn(*args)`` while holding ``_rerank_lock`` and return its result."""
    with _rerank_lock:
        return fn(*args)


@app.post('/v1/rerank')
async def handle_post_request(docs: QADocs, credentials: HTTPAuthorizationCredentials = Security(security)):
    """Rerank ``docs.documents`` against ``docs.query``.

    Returns:
        ``{"results": [...]}`` on success, or ``{"error": "重排出错"}`` if
        reranking raises. The exception is logged with its traceback.

    Raises:
        HTTPException: 401 when the bearer token does not match ``env_bearer_token``.
    """
    token = credentials.credentials
    # Constant-time comparison. Comparing bytes avoids the TypeError that
    # compare_digest raises for non-ASCII str input.
    if env_bearer_token is not None and not secrets.compare_digest(
        token.encode("utf-8"), env_bearer_token.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid token")
    # Model loading (on the first request) and inference are blocking calls.
    # Running them in the threadpool keeps the event loop free to serve other
    # requests instead of freezing the whole server.
    chat = await run_in_threadpool(_call_serialized, Chat)
    try:
        results = await run_in_threadpool(_call_serialized, chat.fit_query_answer_rerank, docs)
        return {"results": results}
    except Exception as e:
        logging.getLogger(__name__).exception(f"报错:\n{e}")
        return {"error": "重排出错"}

if __name__ == "__main__":
    # An ACCESS_TOKEN environment variable overrides the default bearer token.
    token = os.getenv("ACCESS_TOKEN")
    if token is not None:
        env_bearer_token = token
    # Listening port. The PORT environment variable overrides the default 6006.
    port = int(os.getenv("PORT", "6006"))
    try:
        uvicorn.run(app, host='0.0.0.0', port=port)
    except Exception as e:
        print(f"API启动失败!\n报错:\n{e}")
        # Exit non-zero so process supervisors and scripts can see the startup failure.
        raise SystemExit(1) from e

使用以下测试脚本测试模型：

import os

import requests

# Rerank endpoint. The server script binds port 6006 by default; 7013 here
# presumably goes through a port mapping (e.g. Docker). Set RERANK_URL to
# point elsewhere.
url = os.getenv("RERANK_URL", "http://localhost:7013/v1/rerank")

headers = {
    "Content-Type": "application/json",
    # Must match the token the server was started with (its ACCESS_TOKEN).
    "Authorization": f"Bearer {os.getenv('ACCESS_TOKEN', 'sk-tidukjjinarerank')}",
}

data = {
    "model": "bge-reranker-v2-m3",
    "query": "Organic skincare products for sensitive skin",
    "documents": [
        "Eco-friendly kitchenware for modern homes",
        "Biodegradable cleaning supplies for eco-conscious consumers",
        "Organic cotton baby clothes for sensitive skin",
        "Natural organic skincare range for sensitive skin",
        "Tech gadgets for smart homes: 2024 edition",
        "Sustainable gardening tools and compost solutions",
        "Sensitive skin-friendly facial cleansers and toners",
        "Organic food wraps and storage solutions",
        "All-natural pet food for dogs with allergies",
        "Yoga mats made from recycled materials"
    ],
    "top_n": 3
}

# Without a timeout, requests can wait forever if the server never answers.
# The generous limit leaves room for the first request, which loads the model.
response = requests.post(url, headers=headers, json=data, timeout=300)
print(response.json())

GPU资源使用情况截图: image

现在就是一直这样卡着: image

staoxiao commented 4 months ago

抱歉，我们没有使用过 FastAPI，建议到 FastAPI 的仓库提问。

xiaoToby commented 4 months ago

现在能运行了，加上了环境变量 CUDA_VISIBLE_DEVICES。

但是现在还有一个问题：因为这是一个 API，但每次运行时都会要求进行如下图所示的操作: image

应该如何提前配置好trust_remote_code=True? @staoxiao

staoxiao commented 4 months ago

bge-reranker-v2-m3不需要trust_remote_code参数,这个是jina-reranker需要的参数。 你可以参考jina官方的做法调用其reranker

xiaoToby commented 4 months ago

建议将 flag_reranker.py 第 159 行改为：self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, cache_dir=cache_dir, trust_remote_code=True)