FlagOpen / FlagEmbedding

Retrieval and Retrieval-augmented LLMs
MIT License

bge-m3: with the default max_passage_length of 8192, compute_score is very slow #434

Open huangtingwei9988 opened 7 months ago

huangtingwei9988 commented 7 months ago

Tested on an A100.

code:

```python
import time
from FlagEmbedding import BGEM3FlagModel

model = BGEM3FlagModel('/home/admin/bge-m3', use_fp16=True)

sentences_1 = ["What is BGE M3?", "Definition of BM25"]
sentences_2 = ["BGE M3 is an embedding model supporting dense retrieval, lexical matching and multi-vector interaction.",
               "BM25 is a bag-of-words retrieval function that ranks a set of documents based on the query terms appearing in each document"]

sentence_pairs = [[i, j] for i in sentences_1 for j in sentences_2]

# Time compute_score at several max_passage_length settings, 5 runs each.
for max_passage_length in (8192, 2048, 1024):
    for _ in range(5):
        start_time = time.time()
        model.compute_score(sentence_pairs, max_passage_length=max_passage_length)
        cost_time = time.time() - start_time
        print(f"max_passage_length={max_passage_length} cost_time: {cost_time}")
```

Results:

```
max_passage_length=8192 cost_time: 2.3674864768981934
max_passage_length=8192 cost_time: 1.705240249633789
max_passage_length=8192 cost_time: 1.704312801361084
max_passage_length=8192 cost_time: 1.7043406963348389
max_passage_length=8192 cost_time: 1.7050418853759766
max_passage_length=2048 cost_time: 0.17179536819458008
max_passage_length=2048 cost_time: 0.17013168334960938
max_passage_length=2048 cost_time: 0.16979646682739258
max_passage_length=2048 cost_time: 0.17057323455810547
max_passage_length=2048 cost_time: 0.17035984992980957
max_passage_length=1024 cost_time: 0.07496500015258789
max_passage_length=1024 cost_time: 0.07479047775268555
max_passage_length=1024 cost_time: 0.07449865341186523
max_passage_length=1024 cost_time: 0.07444643974304199
max_passage_length=1024 cost_time: 0.07444238662719727
```

hanhainebula commented 7 months ago

Hi, could you share your test code?

huangtingwei9988 commented 7 months ago

> Hi, could you share your test code?

I've updated my comment above with the test code.

hanhainebula commented 7 months ago

Thanks for finding this issue! We found the cause: the `_tokenize` function used in `compute_score` was padding every input to the configured `max_length`. It should use `padding=True`, which pads only to the longest sequence in the batch. The code has now been updated.
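To illustrate the difference (a toy sketch, not the library's actual `_tokenize` implementation): padding every batch to a fixed `max_length` makes the encoder process thousands of pad positions for short inputs, while padding to the batch's longest sequence keeps tensors as small as possible.

```python
def pad_to_max_length(batch, max_length, pad_id=0):
    """Mimics padding='max_length': every row becomes max_length wide."""
    return [seq + [pad_id] * (max_length - len(seq)) for seq in batch]

def pad_to_longest(batch, pad_id=0):
    """Mimics padding=True: rows are padded only to the longest real sequence."""
    longest = max(len(seq) for seq in batch)
    return [seq + [pad_id] * (longest - len(seq)) for seq in batch]

# Two short token-id sequences (lengths 4 and 6).
batch = [[101, 7, 8, 102], [101, 7, 8, 9, 10, 102]]

fixed = pad_to_max_length(batch, max_length=8192)
dynamic = pad_to_longest(batch)

print(len(fixed[0]))    # 8192: the model sees 8192 positions per row
print(len(dynamic[0]))  # 6: only as wide as the longest real input
```

Since transformer attention cost grows super-linearly with sequence length, padding short pairs out to 8192 explains the large slowdown in the benchmark above.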

With the fix, the test code above gives the following results on a single A800 GPU:

```
max_passage_length=8192 cost_time: 0.5043973922729492
max_passage_length=8192 cost_time: 0.028858423233032227
max_passage_length=8192 cost_time: 0.028004169464111328
max_passage_length=8192 cost_time: 0.027806520462036133
max_passage_length=8192 cost_time: 0.028023481369018555
max_passage_length=2048 cost_time: 0.027927637100219727
max_passage_length=2048 cost_time: 0.027832508087158203
max_passage_length=2048 cost_time: 0.029221296310424805
max_passage_length=2048 cost_time: 0.02766108512878418
max_passage_length=2048 cost_time: 0.02771782875061035
max_passage_length=1024 cost_time: 0.027782440185546875
max_passage_length=1024 cost_time: 0.027876853942871094
max_passage_length=1024 cost_time: 0.027625560760498047
max_passage_length=1024 cost_time: 0.02752852439880371
max_passage_length=1024 cost_time: 0.02778339385986328
```