Sshuoshuo / easy-rag

快速入门RAG与私有化部署
128 stars 25 forks source link

gte召回 #3

Open YingxuanW opened 5 months ago

YingxuanW commented 5 months ago

佬你好,请问能提供一下gte召回方案的代码吗?

Sshuoshuo commented 5 months ago

可以参考官方脚本:https://huggingface.co/thenlper/gte-large-zh 或者代码:

def embed_documents(texts, tokenizer, batch_size=32, device='cuda'):
    """Embed a list of documents with the (global) GTE encoder `emb_model`.

    Newlines are flattened to spaces, texts are tokenized in batches,
    and the [CLS] token embedding of each text is L2-normalized.

    Args:
        texts: list of raw document strings.
        tokenizer: HF tokenizer matching `emb_model`.
        batch_size: number of texts encoded per forward pass.
        device: device the tokenized batch is moved to
            (must match where `emb_model` lives — TODO confirm vs DEVICE).

    Returns:
        list of embedding vectors (list[list[float]]), one per input text.
    """
    cleaned = [text.replace("\n", " ") for text in texts]
    all_vectors = []

    for offset in tqdm(range(0, len(cleaned), batch_size)):
        chunk = cleaned[offset:offset + batch_size]
        inputs = tokenizer(chunk, max_length=2048, padding=True,
                           truncation=True, return_tensors='pt').to(device)

        with torch.no_grad():
            outputs = emb_model(**inputs)
            # CLS pooling: first token's hidden state, then L2-normalize
            # so downstream dot products are cosine similarities.
            cls_vecs = outputs.last_hidden_state[:, 0]
            cls_vecs = torch.nn.functional.normalize(cls_vecs, p=2, dim=1)
            all_vectors.extend(cls_vecs.tolist())

    return all_vectors

def embed_query(text, tokenizer, device='cuda'):
    """Embed a single query string with the (global) GTE encoder `emb_model`.

    Mirrors `embed_documents`: CLS pooling followed by L2 normalization,
    so query/document dot products are cosine similarities.

    Args:
        text: raw query string.
        tokenizer: HF tokenizer matching `emb_model`.
        device: device the tokenized input is moved to.

    Returns:
        list[float]: the normalized embedding of `text`.
    """
    encoded_input = tokenizer([text], padding=True, max_length=2048,
                              truncation=True, return_tensors='pt').to(device)
    with torch.no_grad():
        model_output = emb_model(**encoded_input)
        # Consistency fix: use `.last_hidden_state` like embed_documents
        # (model_output[0] is the same tensor, this is just explicit),
        # and keep normalization inside no_grad as the sibling does.
        sentence_embeddings = model_output.last_hidden_state[:, 0]
        sentence_embeddings = torch.nn.functional.normalize(
            sentence_embeddings, p=2, dim=1)
    return sentence_embeddings[0].tolist()

def retrieval_with_docs_emb(query, all_embeddings, n=20):
    """Return indices of the top-`n` documents for `query`.

    Embeds the query (via the module-level `tokenizer`) and scores it
    against pre-computed, L2-normalized document embeddings; because both
    sides are unit vectors, the dot product is the cosine similarity.

    Args:
        query: raw query string.
        all_embeddings: (num_docs, dim) tensor of document embeddings.
        n: number of top matches to return.

    Returns:
        1-D tensor of document indices, best match first.
    """
    query_vec = torch.FloatTensor(embed_query(query, tokenizer))
    query_vec = query_vec.to(all_embeddings.device)
    # Row-wise dot product against every document embedding.
    scores = (all_embeddings * query_vec).sum(dim=1)
    # Highest-scoring documents first.
    return scores.argsort(descending=True)[:n]

# --- Script-level setup: load GTE model, embed the corpus, run retrieval ---
EMB_MODEL_NAME_OR_PATH = " "  # TODO: set to a local path or HF id, e.g. "thenlper/gte-large-zh"
# fp16 inference; DEVICE and `passages` must be defined before this point.
emb_model = AutoModel.from_pretrained(EMB_MODEL_NAME_OR_PATH, trust_remote_code=True).half().to(DEVICE).eval()
tokenizer = AutoTokenizer.from_pretrained(EMB_MODEL_NAME_OR_PATH, trust_remote_code=True)
passage_embs = torch.FloatTensor(embed_documents(passages, tokenizer)).to('cuda')

# Bug fix: retrieval_with_docs_emb takes (query, all_embeddings, n). Passing
# emb_model as a third positional argument binds it to `n` and then `n=50`
# raises "TypeError: got multiple values for argument 'n'". The function uses
# the module-level emb_model/tokenizer, so the model must not be passed here.
retrieval_res = retrieval_with_docs_emb(query, passage_embs, n=50)