Open YingxuanW opened 5 months ago
可以参考官方脚本:https://huggingface.co/thenlper/gte-large-zh 或者代码:
def embed_documents(texts, tokenizer, batch_size=32, device='cuda', model=None):
    """Embed a list of documents with a CLS-pooling encoder (e.g. gte-large-zh).

    Args:
        texts: list of document strings to embed.
        tokenizer: HuggingFace tokenizer matching the embedding model.
        batch_size: number of texts encoded per forward pass.
        device: torch device the tokenized batch is moved to.
        model: encoder to run; defaults to the module-level ``emb_model``
            so existing callers keep working.

    Returns:
        A list of L2-normalized embedding vectors (one list of floats per text).
    """
    if model is None:
        model = emb_model  # original behavior: fall back to the module-level model
    # Progress bar is cosmetic; don't hard-require tqdm.
    try:
        from tqdm import tqdm as _progress
    except ImportError:
        def _progress(it):
            return it
    num_texts = len(texts)
    # Newlines tend to hurt sentence-embedding quality; flatten them to spaces.
    texts = [t.replace("\n", " ") for t in texts]
    sentence_embeddings = []
    for start in _progress(range(0, num_texts, batch_size)):
        # Slicing already clamps at the end of the list; no explicit min() needed.
        batch_texts = texts[start:start + batch_size]
        encoded_input = tokenizer(batch_texts, max_length=2048, padding=True,
                                  truncation=True, return_tensors='pt').to(device)
        with torch.no_grad():
            model_output = model(**encoded_input)
        # CLS pooling: hidden state of the first token of each sequence.
        batch_embeddings = model_output.last_hidden_state[:, 0]
        batch_embeddings = torch.nn.functional.normalize(batch_embeddings, p=2, dim=1)
        sentence_embeddings.extend(batch_embeddings.tolist())
    return sentence_embeddings
def embed_query(text, tokenizer, device='cuda', model=None):
    """Embed a single query string with the same CLS-pooling scheme as
    ``embed_documents``.

    Args:
        text: the query string.
        tokenizer: HuggingFace tokenizer matching the embedding model.
        device: torch device the tokenized input is moved to.
        model: encoder to run; defaults to the module-level ``emb_model``
            so existing callers keep working.

    Returns:
        The L2-normalized embedding as a flat list of floats.
    """
    if model is None:
        model = emb_model  # original behavior: fall back to the module-level model
    encoded_input = tokenizer([text], padding=True, max_length=2048,
                              truncation=True, return_tensors='pt').to(device)
    with torch.no_grad():
        model_output = model(**encoded_input)
    # CLS pooling via the named attribute (model_output[0] is the same tensor
    # for HF model outputs) — consistent with embed_documents.
    sentence_embeddings = model_output.last_hidden_state[:, 0]
    sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
    return sentence_embeddings[0].tolist()
def retrieval_with_docs_emb(query, all_embeddings, model=None, n=20):
    """Return indices of the top-``n`` documents for ``query`` by cosine score.

    Args:
        query: the query string.
        all_embeddings: (num_docs, dim) tensor of L2-normalized document
            embeddings.
        model: accepted so the existing call site
            ``retrieval_with_docs_emb(query, passage_embs, emb_model, n=50)``
            works. With the original ``(query, all_embeddings, n=20)``
            signature that call bound the model to ``n`` and then raised a
            TypeError on the duplicate ``n=50`` keyword. Embedding itself still
            goes through the module-level model inside ``embed_query``.
        n: number of top-scoring document indices to return.

    Returns:
        Tensor of the ``n`` document indices with the highest cosine scores.
    """
    query_embedding = torch.FloatTensor(embed_query(query, tokenizer)).to(all_embeddings.device)
    # Both sides are unit-normalized, so a dot product equals cosine similarity.
    cosine_scores = (all_embeddings * query_embedding).sum(1)
    # Negate so argsort yields descending-score order, then keep the first n.
    top_sentences = (-cosine_scores).argsort()[:n]
    return top_sentences
# --- Script-level setup (depends on DEVICE, `passages` and `query` defined elsewhere) ---
# TODO(review): " " is a placeholder — set the real model name or local path here.
EMB_MODEL_NAME_OR_PATH = " "
# fp16 inference; NOTE(review): .half() assumes a CUDA DEVICE — confirm before running on CPU.
emb_model = AutoModel.from_pretrained(EMB_MODEL_NAME_OR_PATH, trust_remote_code=True).half().to(DEVICE).eval()
tokenizer = AutoTokenizer.from_pretrained(EMB_MODEL_NAME_OR_PATH, trust_remote_code=True)
passage_embs = torch.FloatTensor(embed_documents(passages, tokenizer)).to('cuda')
# NOTE(review): retrieval_with_docs_emb is defined as (query, all_embeddings, n=20);
# passing emb_model as the 3rd positional argument binds it to n and then the
# n=50 keyword raises "got multiple values for argument 'n'" — likely a bug.
retrieval_res = retrieval_with_docs_emb(query, passage_embs, emb_model, n=50)
佬你好,请问能提供一下gte召回方案的代码吗?