FlagOpen / FlagEmbedding

Retrieval and Retrieval-augmented LLMs
MIT License
7.53k stars 543 forks source link

bge-visual每次出来的embed值都不一样 #659

Closed charliedream1 closed 7 months ago

charliedream1 commented 7 months ago

使用的是 bge-m3,调用 `candi_emb_1 = model.encode(text="The Mid-Hudson Bridge, spanning the Hudson River between Poughkeepsie and Highland.", image="./imgs/wiki_candi_1.jpg")`,同样的输入每次运行得到的 embedding 都不一样。

JUNJIE99 commented 7 months ago

你好,请尝试在推理前添加model.eval()

charliedream1 commented 7 months ago

加了的

---原始邮件--- 发件人: <redacted>; 发送时间: 2024年4月8日(周一) 晚上9:11; 收件人: <redacted>; 抄送: "Optimus" <redacted>; 主题: Re: [FlagOpen/FlagEmbedding] bge-visual每次出来的embed值都不一样 (Issue #659)

你好,请尝试在推理前添加model.eval()

— Reply to this email directly, view it on GitHub, or unsubscribe. You are receiving this because you authored the thread. Message ID: <redacted>

JUNJIE99 commented 7 months ago

你好,我这边测试每次结果都是一样的,能否提供您的运行代码?

这是我使用bge-m3的测试代码,您可以参考一下

# Reference script: encode a query and three candidates with Visualized BGE
# (bge-m3 backbone) and print the query/candidate similarity scores. With
# model.eval() and torch.no_grad() the printed values are the same every run.
import torch
from FlagEmbedding.visual.modeling import Visualized_BGE

query_text = "Are there sidewalks on both sides of the Mid-Hudson Bridge?"
# (text, image path) pairs; an image of None marks a text-only candidate.
candidates = [
    ("The Mid-Hudson Bridge, spanning the Hudson River between Poughkeepsie and Highland.", "./imgs/wiki_candi_1.jpg"),
    ("Golden_Gate_Bridge", "./imgs/wiki_candi_2.jpg"),
    ("The Mid-Hudson Bridge was designated as a New York State Historic Civil Engineering Landmark by the American Society of Civil Engineers in 1983. The bridge was renamed the \"Franklin Delano Roosevelt Mid-Hudson Bridge\" in 1994.", None),
]

model = Visualized_BGE(model_name_bge="BAAI/bge-m3", model_weight="your path")
model.eval()  # leave training mode so repeated encodes are reproducible

with torch.no_grad():
    query_emb = model.encode(text=query_text)
    candi_embs = []
    for text, image in candidates:
        if image is None:
            candi_embs.append(model.encode(text=text))
        else:
            candi_embs.append(model.encode(text=text, image=image))

# One query-vs-candidate score per candidate, in candidate order.
sims = [query_emb @ emb.T for emb in candi_embs]
print(*sims) # tensor([[0.6303]]) tensor([[0.3493]]) tensor([[0.5599]])
print(candi_embs[0]) # tensor([[ 0.0120,  0.0005, -0.0165,  ..., -0.0078,  0.0148,  0.0139]])
charliedream1 commented 7 months ago
####### Use Visualized BGE doing multi-modal knowledge retrieval
# Reproduction script for this issue: encodes one text query and three
# candidates (two text+image, one text-only) and prints the query/candidate
# scores.
import torch
from FlagEmbedding.visual.modeling import Visualized_BGE

# Local paths to the bge-m3 text backbone and the Visualized_m3 weight file
# (these look like Hugging Face cache snapshot paths; adjust to your machine).
bge_m3_mdl_path = '/data/BAAI/models--BAAI--bge-m3/snapshots/73a15ad29ab604f3bdc31601849a9defe86d563f'
visual_mdl_path = '/data/BAAI/models--BAAI--bge-visualized/snapshots/98db10b10d22620010d06f11733346e1c98c34aa/Visualized_m3.pth'
model = Visualized_BGE(model_name_bge=bge_m3_mdl_path, model_weight=visual_mdl_path)
# NOTE(review): model.eval() is missing here. A torch.nn.Module starts in
# training mode, so train-only behaviour (e.g. dropout) stays active and the
# same input yields a different embedding on every run. Calling model.eval()
# right after construction is what resolved this issue.

# torch.no_grad() only disables autograd tracking; it does NOT put the model
# into eval mode, so on its own it does not make the outputs deterministic.
with torch.no_grad():
    query_emb = model.encode(text="Are there sidewalks on both sides of the Mid-Hudson Bridge?")
    candi_emb_1 = model.encode(text="The Mid-Hudson Bridge, spanning the Hudson River between Poughkeepsie and Highland.", image="./imgs/wiki_candi_1.jpg")
    candi_emb_2 = model.encode(text="Golden_Gate_Bridge", image="./imgs/wiki_candi_2.jpg")
    candi_emb_3 = model.encode(text="The Mid-Hudson Bridge was designated as a New York State Historic Civil Engineering Landmark by the American Society of Civil Engineers in 1983. The bridge was renamed the \"Franklin Delano Roosevelt Mid-Hudson Bridge\" in 1994.")

# Dot product of the query embedding with each candidate embedding
# (a cosine similarity only if encode() returns L2-normalized vectors -- not
# verifiable from this snippet).
sim_1 = query_emb @ candi_emb_1.T
sim_2 = query_emb @ candi_emb_2.T
sim_3 = query_emb @ candi_emb_3.T
# Values below are from a single run without model.eval(); they change per run.
print(sim_1, sim_2, sim_3) # tensor([[0.6932]]) tensor([[0.4441]]) tensor([[0.6415]])

运行代码如上。另外修改了一下FlagEmbedding/visual/modeling.py

from torch import nn, Tensor

class Visualized_BGE(nn.Module):
    """Multi-modal BGE embedding model (excerpt).

    Excerpt of a locally modified ``FlagEmbedding/visual/modeling.py``: only
    the start of ``__init__`` is shown, and the rest of the class is omitted.
    The visible part validates the BGE backbone name and selects an EVA02-CLIP
    model name plus ``hidden_dim``/``depth`` values for that backbone.
    """
    def __init__(self,
                 model_name_bge: str = None,
                 model_weight = None, # "/path/to/your/weight/file/"
                 normlized: bool = True,
                 sentence_pooling_method: str = 'cls',
                 negatives_cross_device: bool = False,
                 temperature: float = 0.02, # 1.0
                 ):
        """Validate the backbone name and pick the matching EVA-CLIP settings.

        Args:
            model_name_bge: Hugging Face id or local path of the BGE text
                model. It is matched by substring, so it must contain
                "bge-base-en-v1.5" or "bge-m3"; local snapshot paths such as
                ".../models--BAAI--bge-m3/snapshots/<hash>" therefore pass.
                The default is None, but None fails the ``in`` check with a
                TypeError, so in practice the argument is required.
            model_weight: path to the visual weight file; required.
            normlized: (sic) not used in the visible part of ``__init__``.
            sentence_pooling_method, negatives_cross_device, temperature:
                not used in the visible part of ``__init__``.

        Raises:
            AssertionError: if ``model_weight`` is None (the assert is skipped
                under ``python -O``).
            ValueError: if ``model_name_bge`` contains neither supported name.
        """
        super().__init__()
        # Presumably the original exact-match check on the hub id; it is
        # replaced by the substring check below so that local paths work.
        # assert model_name_bge in ["BAAI/bge-base-en-v1.5", "BAAI/bge-m3"]
        # NOTE(review): assert is not real input validation (stripped under
        # -O); raising ValueError would be more robust.
        assert model_weight is not None
        # Accept the name if it contains any supported backbone identifier
        # (equivalent to any(n in model_name_bge for n in [...])).
        name_flag = False
        for name in ["bge-base-en-v1.5", "bge-m3"]:
            if name in model_name_bge:
                name_flag = True

        if not name_flag:
            # NOTE(review): `name` is the leftover loop variable, which is
            # always "bge-m3" at this point, so the message neither shows the
            # rejected value nor reads correctly. Consider
            # f"unsupported model_name_bge: {model_name_bge!r}".
            raise ValueError(f"model_name_bge should not be {name}")

        self.model_name_bge = model_name_bge

        # Previous exact-match version of the selection below, kept for reference.
        # if model_name_bge == 'BAAI/bge-base-en-v1.5':
        #     model_name_eva = "EVA02-CLIP-B-16"
        #     self.hidden_dim = 768
        #     self.depth = 12
        # elif model_name_bge == 'BAAI/bge-m3':
        #     model_name_eva = "EVA02-CLIP-L-14"
        #     self.hidden_dim = 1024
        #     self.depth = 24

        # Substring match selects the EVA02-CLIP variant and the hidden_dim /
        # depth values paired with each backbone; how they are used lies
        # outside this excerpt.
        if 'bge-base-en-v1.5' in model_name_bge:
            model_name_eva = "EVA02-CLIP-B-16"
            self.hidden_dim = 768
            self.depth = 12
        elif 'bge-m3' in model_name_bge:
            model_name_eva = "EVA02-CLIP-L-14"
            self.hidden_dim = 1024
            self.depth = 24
JUNJIE99 commented 7 months ago
####### Use Visualized BGE doing multi-modal knowledge retrieval
# Reproduction script for this issue: encodes one text query and three
# candidates (two text+image, one text-only) and prints the query/candidate
# scores.
import torch
from FlagEmbedding.visual.modeling import Visualized_BGE

# Local paths to the bge-m3 text backbone and the Visualized_m3 weight file
# (these look like Hugging Face cache snapshot paths; adjust to your machine).
bge_m3_mdl_path = '/data/BAAI/models--BAAI--bge-m3/snapshots/73a15ad29ab604f3bdc31601849a9defe86d563f'
visual_mdl_path = '/data/BAAI/models--BAAI--bge-visualized/snapshots/98db10b10d22620010d06f11733346e1c98c34aa/Visualized_m3.pth'
model = Visualized_BGE(model_name_bge=bge_m3_mdl_path, model_weight=visual_mdl_path)
# NOTE(review): model.eval() is missing here. A torch.nn.Module starts in
# training mode, so train-only behaviour (e.g. dropout) stays active and the
# same input yields a different embedding on every run. Calling model.eval()
# right after construction is what resolved this issue.

# torch.no_grad() only disables autograd tracking; it does NOT put the model
# into eval mode, so on its own it does not make the outputs deterministic.
with torch.no_grad():
    query_emb = model.encode(text="Are there sidewalks on both sides of the Mid-Hudson Bridge?")
    candi_emb_1 = model.encode(text="The Mid-Hudson Bridge, spanning the Hudson River between Poughkeepsie and Highland.", image="./imgs/wiki_candi_1.jpg")
    candi_emb_2 = model.encode(text="Golden_Gate_Bridge", image="./imgs/wiki_candi_2.jpg")
    candi_emb_3 = model.encode(text="The Mid-Hudson Bridge was designated as a New York State Historic Civil Engineering Landmark by the American Society of Civil Engineers in 1983. The bridge was renamed the \"Franklin Delano Roosevelt Mid-Hudson Bridge\" in 1994.")

# Dot product of the query embedding with each candidate embedding
# (a cosine similarity only if encode() returns L2-normalized vectors -- not
# verifiable from this snippet).
sim_1 = query_emb @ candi_emb_1.T
sim_2 = query_emb @ candi_emb_2.T
sim_3 = query_emb @ candi_emb_3.T
# Values below are from a single run without model.eval(); they change per run.
print(sim_1, sim_2, sim_3) # tensor([[0.6932]]) tensor([[0.4441]]) tensor([[0.6415]])

运行代码如上。另外修改了一下FlagEmbedding/visual/modeling.py

from torch import nn, Tensor

class Visualized_BGE(nn.Module):
    """Multi-modal BGE embedding model (excerpt).

    Excerpt of a locally modified ``FlagEmbedding/visual/modeling.py``: only
    the start of ``__init__`` is shown, and the rest of the class is omitted.
    The visible part validates the BGE backbone name and selects an EVA02-CLIP
    model name plus ``hidden_dim``/``depth`` values for that backbone.
    """
    def __init__(self,
                 model_name_bge: str = None,
                 model_weight = None, # "/path/to/your/weight/file/"
                 normlized: bool = True,
                 sentence_pooling_method: str = 'cls',
                 negatives_cross_device: bool = False,
                 temperature: float = 0.02, # 1.0
                 ):
        """Validate the backbone name and pick the matching EVA-CLIP settings.

        Args:
            model_name_bge: Hugging Face id or local path of the BGE text
                model. It is matched by substring, so it must contain
                "bge-base-en-v1.5" or "bge-m3"; local snapshot paths such as
                ".../models--BAAI--bge-m3/snapshots/<hash>" therefore pass.
                The default is None, but None fails the ``in`` check with a
                TypeError, so in practice the argument is required.
            model_weight: path to the visual weight file; required.
            normlized: (sic) not used in the visible part of ``__init__``.
            sentence_pooling_method, negatives_cross_device, temperature:
                not used in the visible part of ``__init__``.

        Raises:
            AssertionError: if ``model_weight`` is None (the assert is skipped
                under ``python -O``).
            ValueError: if ``model_name_bge`` contains neither supported name.
        """
        super().__init__()
        # Presumably the original exact-match check on the hub id; it is
        # replaced by the substring check below so that local paths work.
        # assert model_name_bge in ["BAAI/bge-base-en-v1.5", "BAAI/bge-m3"]
        # NOTE(review): assert is not real input validation (stripped under
        # -O); raising ValueError would be more robust.
        assert model_weight is not None
        # Accept the name if it contains any supported backbone identifier
        # (equivalent to any(n in model_name_bge for n in [...])).
        name_flag = False
        for name in ["bge-base-en-v1.5", "bge-m3"]:
            if name in model_name_bge:
                name_flag = True

        if not name_flag:
            # NOTE(review): `name` is the leftover loop variable, which is
            # always "bge-m3" at this point, so the message neither shows the
            # rejected value nor reads correctly. Consider
            # f"unsupported model_name_bge: {model_name_bge!r}".
            raise ValueError(f"model_name_bge should not be {name}")

        self.model_name_bge = model_name_bge

        # Previous exact-match version of the selection below, kept for reference.
        # if model_name_bge == 'BAAI/bge-base-en-v1.5':
        #     model_name_eva = "EVA02-CLIP-B-16"
        #     self.hidden_dim = 768
        #     self.depth = 12
        # elif model_name_bge == 'BAAI/bge-m3':
        #     model_name_eva = "EVA02-CLIP-L-14"
        #     self.hidden_dim = 1024
        #     self.depth = 24

        # Substring match selects the EVA02-CLIP variant and the hidden_dim /
        # depth values paired with each backbone; how they are used lies
        # outside this excerpt.
        if 'bge-base-en-v1.5' in model_name_bge:
            model_name_eva = "EVA02-CLIP-B-16"
            self.hidden_dim = 768
            self.depth = 12
        elif 'bge-m3' in model_name_bge:
            model_name_eva = "EVA02-CLIP-L-14"
            self.hidden_dim = 1024
            self.depth = 24

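请在 `model = Visualized_BGE(model_name_bge=bge_m3_mdl_path, model_weight=visual_mdl_path)` 后面加上 `model.eval()`。否则模型仍处于 PyTorch 默认的训练模式(dropout 等训练期行为仍会生效),同样的输入每次编码得到的 embedding 都会不同;仅使用 `torch.no_grad()` 并不会切换到推理模式。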

charliedream1 commented 7 months ago

已解决,谢谢