Closed charliedream1 closed 7 months ago
你好,请尝试在推理前添加model.eval()
加了的
---原始邮件--- 发件人: @.> 发送时间: 2024年4月8日(周一) 晚上9:11 收件人: @.>; 抄送: "Optimus @.**@.>; 主题: Re: [FlagOpen/FlagEmbedding] bge-visual每次出来的embed值都不一样 (Issue #659)
你好,请尝试在推理前添加model.eval()
— Reply to this email directly, view it on GitHub, or unsubscribe. You are receiving this because you authored the thread.Message ID: @.***>
你好,我这边测试每次结果都是一样的,能否提供您的运行代码?
这是我使用bge-m3的测试代码,您可以参考一下
import torch
from FlagEmbedding.visual.modeling import Visualized_BGE
# Maintainer's reference script: build the model, switch it to eval mode,
# then encode one text-only query and three candidates.
model = Visualized_BGE(model_name_bge="BAAI/bge-m3", model_weight="your path")
model.eval()  # called before any encode(), as advised at the top of this thread

# Gradients are not needed for inference, so encoding runs under no_grad().
with torch.no_grad():
    query_emb = model.encode(text="Are there sidewalks on both sides of the Mid-Hudson Bridge?")
    # Candidates 1 and 2 pass both text and an image path; candidate 3 is text only.
    candi_emb_1 = model.encode(text="The Mid-Hudson Bridge, spanning the Hudson River between Poughkeepsie and Highland.", image="./imgs/wiki_candi_1.jpg")
    candi_emb_2 = model.encode(text="Golden_Gate_Bridge", image="./imgs/wiki_candi_2.jpg")
    candi_emb_3 = model.encode(text="The Mid-Hudson Bridge was designated as a New York State Historic Civil Engineering Landmark by the American Society of Civil Engineers in 1983. The bridge was renamed the \"Franklin Delano Roosevelt Mid-Hudson Bridge\" in 1994.")

# Query/candidate scores: query embedding times the transposed candidate embedding.
sim_1 = query_emb @ candi_emb_1.T
sim_2 = query_emb @ candi_emb_2.T
sim_3 = query_emb @ candi_emb_3.T
print(sim_1, sim_2, sim_3) # tensor([[0.6303]]) tensor([[0.3493]]) tensor([[0.5599]])
print(candi_emb_1) # tensor([[ 0.0120, 0.0005, -0.0165, ..., -0.0078, 0.0148, 0.0139]])
####### Use Visualized BGE doing multi-modal knowledge retrieval
import torch
from FlagEmbedding.visual.modeling import Visualized_BGE
# Reporter's reproduction script: both checkpoints are loaded from local snapshot paths.
bge_m3_mdl_path = '/data/BAAI/models--BAAI--bge-m3/snapshots/73a15ad29ab604f3bdc31601849a9defe86d563f'
visual_mdl_path = '/data/BAAI/models--BAAI--bge-visualized/snapshots/98db10b10d22620010d06f11733346e1c98c34aa/Visualized_m3.pth'

model = Visualized_BGE(model_name_bge=bge_m3_mdl_path, model_weight=visual_mdl_path)
# NOTE: model.eval() is never called in this script. The maintainer's reply
# later in this thread asks for it to be added right after the line above,
# after which the reporter confirms the issue is resolved. Presumably the
# module is otherwise left in training mode, where stochastic layers make
# repeated encodes differ -- confirm against the PyTorch nn.Module docs.

# Gradients are not needed for inference, so encoding runs under no_grad().
with torch.no_grad():
    query_emb = model.encode(text="Are there sidewalks on both sides of the Mid-Hudson Bridge?")
    # Candidates 1 and 2 pass both text and an image path; candidate 3 is text only.
    candi_emb_1 = model.encode(text="The Mid-Hudson Bridge, spanning the Hudson River between Poughkeepsie and Highland.", image="./imgs/wiki_candi_1.jpg")
    candi_emb_2 = model.encode(text="Golden_Gate_Bridge", image="./imgs/wiki_candi_2.jpg")
    candi_emb_3 = model.encode(text="The Mid-Hudson Bridge was designated as a New York State Historic Civil Engineering Landmark by the American Society of Civil Engineers in 1983. The bridge was renamed the \"Franklin Delano Roosevelt Mid-Hudson Bridge\" in 1994.")

# Query/candidate scores: query embedding times the transposed candidate embedding.
sim_1 = query_emb @ candi_emb_1.T
sim_2 = query_emb @ candi_emb_2.T
sim_3 = query_emb @ candi_emb_3.T
print(sim_1, sim_2, sim_3) # tensor([[0.6932]]) tensor([[0.4441]]) tensor([[0.6415]])
运行代码如上。另外修改了一下FlagEmbedding/visual/modeling.py
from torch import nn, Tensor
class Visualized_BGE(nn.Module):
    """Excerpt of the reporter's locally modified ``Visualized_BGE`` class.

    Only the opening part of ``__init__`` was posted in the thread; the
    excerpt stops after the backbone-specific settings are chosen, so the
    rest of the constructor is not shown here.
    """

    def __init__(self,
                 model_name_bge: str = None,
                 model_weight = None, # "/path/to/your/weight/file/"
                 normlized: bool = True,
                 sentence_pooling_method: str = 'cls',
                 negatives_cross_device: bool = False,
                 temperature: float = 0.02, # 1.0
                 ):
        """Validate the arguments and pick the settings for the chosen BGE model.

        Args:
            model_name_bge: Model name or local path; it must contain
                ``"bge-base-en-v1.5"`` or ``"bge-m3"``.
            model_weight: Path to the weight file; must not be ``None``.
            normlized, sentence_pooling_method, negatives_cross_device,
                temperature: Accepted by the signature but not used in the
                visible part of the constructor.

        Raises:
            AssertionError: If ``model_weight`` is ``None``.
            ValueError: If ``model_name_bge`` is ``None`` or names neither
                supported model.
        """
        super().__init__()
        assert model_weight is not None

        # The reporter replaced the upstream exact-name comparison with a
        # substring match so that local snapshot paths (which merely contain
        # the model name) are accepted as well.
        supported_names = ("bge-base-en-v1.5", "bge-m3")
        if model_name_bge is None or not any(
            candidate in model_name_bge for candidate in supported_names
        ):
            # Report the rejected value itself, not the last name that was tried.
            raise ValueError(f"model_name_bge should not be {model_name_bge}")

        self.model_name_bge = model_name_bge

        # model_name_eva is not used within this excerpt; presumably the
        # omitted remainder of the constructor consumes it -- TODO confirm.
        if 'bge-base-en-v1.5' in model_name_bge:
            model_name_eva = "EVA02-CLIP-B-16"
            self.hidden_dim = 768
            self.depth = 12
        elif 'bge-m3' in model_name_bge:
            model_name_eva = "EVA02-CLIP-L-14"
            self.hidden_dim = 1024
            self.depth = 24
####### Use Visualized BGE doing multi-modal knowledge retrieval
import torch
from FlagEmbedding.visual.modeling import Visualized_BGE

# Quoted copy of the reporter's reproduction script, with its line breaks restored.
bge_m3_mdl_path = '/data/BAAI/models--BAAI--bge-m3/snapshots/73a15ad29ab604f3bdc31601849a9defe86d563f'
visual_mdl_path = '/data/BAAI/models--BAAI--bge-visualized/snapshots/98db10b10d22620010d06f11733346e1c98c34aa/Visualized_m3.pth'

model = Visualized_BGE(model_name_bge=bge_m3_mdl_path, model_weight=visual_mdl_path)
# NOTE: model.eval() is never called in this script; the maintainer's reply
# further down asks for it to be added right after the line above.

# Gradients are not needed for inference, so encoding runs under no_grad().
with torch.no_grad():
    query_emb = model.encode(text="Are there sidewalks on both sides of the Mid-Hudson Bridge?")
    # Candidates 1 and 2 pass both text and an image path; candidate 3 is text only.
    candi_emb_1 = model.encode(text="The Mid-Hudson Bridge, spanning the Hudson River between Poughkeepsie and Highland.", image="./imgs/wiki_candi_1.jpg")
    candi_emb_2 = model.encode(text="Golden_Gate_Bridge", image="./imgs/wiki_candi_2.jpg")
    candi_emb_3 = model.encode(text="The Mid-Hudson Bridge was designated as a New York State Historic Civil Engineering Landmark by the American Society of Civil Engineers in 1983. The bridge was renamed the \"Franklin Delano Roosevelt Mid-Hudson Bridge\" in 1994.")

# Query/candidate scores: query embedding times the transposed candidate embedding.
sim_1 = query_emb @ candi_emb_1.T
sim_2 = query_emb @ candi_emb_2.T
sim_3 = query_emb @ candi_emb_3.T
print(sim_1, sim_2, sim_3) # tensor([[0.6932]]) tensor([[0.4441]]) tensor([[0.6415]])
运行代码如上。另外修改了一下FlagEmbedding/visual/modeling.py
from torch import nn, Tensor


class Visualized_BGE(nn.Module):
    """Quoted copy of the reporter's modified ``Visualized_BGE`` excerpt.

    Only the opening part of ``__init__`` was posted; the excerpt stops after
    the backbone-specific settings are chosen, so the rest of the constructor
    is not shown here.
    """

    def __init__(self,
                 model_name_bge: str = None,
                 model_weight = None, # "/path/to/your/weight/file/"
                 normlized: bool = True,
                 sentence_pooling_method: str = 'cls',
                 negatives_cross_device: bool = False,
                 temperature: float = 0.02, # 1.0
                 ):
        """Validate the arguments and pick the settings for the chosen BGE model.

        Args:
            model_name_bge: Model name or local path; it must contain
                ``"bge-base-en-v1.5"`` or ``"bge-m3"``.
            model_weight: Path to the weight file; must not be ``None``.
            normlized, sentence_pooling_method, negatives_cross_device,
                temperature: Accepted by the signature but not used in the
                visible part of the constructor.

        Raises:
            AssertionError: If ``model_weight`` is ``None``.
            ValueError: If ``model_name_bge`` is ``None`` or names neither
                supported model.
        """
        super().__init__()
        assert model_weight is not None

        # Substring match instead of an exact-name comparison, so that local
        # snapshot paths which merely contain the model name are accepted.
        supported_names = ("bge-base-en-v1.5", "bge-m3")
        if model_name_bge is None or not any(
            candidate in model_name_bge for candidate in supported_names
        ):
            # Report the rejected value itself, not the last name that was tried.
            raise ValueError(f"model_name_bge should not be {model_name_bge}")

        self.model_name_bge = model_name_bge

        # model_name_eva is not used within this excerpt; presumably the
        # omitted remainder of the constructor consumes it -- TODO confirm.
        if 'bge-base-en-v1.5' in model_name_bge:
            model_name_eva = "EVA02-CLIP-B-16"
            self.hidden_dim = 768
            self.depth = 12
        elif 'bge-m3' in model_name_bge:
            model_name_eva = "EVA02-CLIP-L-14"
            self.hidden_dim = 1024
            self.depth = 24
请在model = Visualized_BGE(model_name_bge=bge_m3_mdl_path, model_weight=visual_mdl_path)
后面加上model.eval()
已解决,谢谢
使用的是bge-m3, candi_emb_1 = model.encode(text="The Mid-Hudson Bridge, spanning the Hudson River between Poughkeepsie and Highland.", image="./imgs/wiki_candi_1.jpg")