JUNJIE99 / VISTA_Evaluation_FineTuning

Evaluation code and datasets for the ACL 2024 paper, VISTA: Visualized Text Embedding for Universal Multi-Modal Retrieval. The original code and model can be accessed at FlagEmbedding.
https://github.com/FlagOpen/FlagEmbedding/tree/master/research/visual_bge

Question about VISTA fine-tuning implementation. #8

Open kimwongyuda opened 1 month ago

kimwongyuda commented 1 month ago

I have two questions about the fine-tuning implementation.

  1. At VISTA_Evaluation_FineTuning/downstream_finetune_example/, when I run run_ds_cirr.py and the eva_clip weights are loaded in modeling_ds_cirr.py via `self.model_visual, self.preprocess_train, self.preprocess_val = create_eva_vision_and_transforms(model_name_eva, eva_pretrained_path, force_custom_clip=True)`, the following missing_keys warning appears:

     [11-10-2024 14:13:40] INFO: incompatible_keys.missing_keys: ['visual.rope.freqs_cos', 'visual.rope.freqs_sin', 'visual.blocks.0.attn.rope.freqs_cos', 'visual.blocks.0.attn.rope.freqs_sin', 'visual.blocks.1.attn.rope.freqs_cos', 'visual.blocks.1.attn.rope.freqs_sin', 'visual.blocks.2.attn.rope.freqs_cos', 'visual.blocks.2.attn.rope.freqs_sin', 'visual.blocks.3.attn.rope.freqs_cos', 'visual.blocks.3.attn.rope.freqs_sin', 'visual.blocks.4.attn.rope.freqs_cos', 'visual.blocks.4.attn.rope.freqs_sin', 'visual.blocks.5.attn.rope.freqs_cos', 'visual.blocks.5.attn.rope.freqs_sin', 'visual.blocks.6.attn.rope.freqs_cos', 'visual.blocks.6.attn.rope.freqs_sin', 'visual.blocks.7.attn.rope.freqs_cos', 'visual.blocks.7.attn.rope.freqs_sin', 'visual.blocks.8.attn.rope.freqs_cos', 'visual.blocks.8.attn.rope.freqs_sin', 'visual.blocks.9.attn.rope.freqs_cos', 'visual.blocks.9.attn.rope.freqs_sin', 'visual.blocks.10.attn.rope.freqs_cos', 'visual.blocks.10.attn.rope.freqs_sin', 'visual.blocks.11.attn.rope.freqs_cos', 'visual.blocks.11.attn.rope.freqs_sin']

However, when I load the eva_clip weights via VISTA_Evaluation_FineTuning/evaluation_example_webqa/BGE-base/modeling_evaluation_base.py, that message does not appear.

Why does the warning appear in the former case but not the latter?

If I fine-tune from the pre-trained weights Visualized_base_en_v1.5.pth, I expect the issue above does not matter, because the pre-trained weights should fill every layer of the model without gaps. Is my expectation correct?
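
To make the expectation concrete, this is the kind of check I have in mind (just a sketch, not code from this repository; `learnable_keys_missing` is a hypothetical helper and the checkpoint is assumed to be a plain state_dict):

```python
import torch
from torch import nn

def learnable_keys_missing(model: nn.Module, ckpt_path: str) -> list:
    """Return missing checkpoint keys that correspond to learnable parameters;
    non-learnable buffers (e.g. the RoPE cos/sin tables) are ignored."""
    state = torch.load(ckpt_path, map_location="cpu")  # assumes a plain state_dict
    result = model.load_state_dict(state, strict=False)
    learnable = {name for name, _ in model.named_parameters()}
    return [k for k in result.missing_keys if k in learnable]

# e.g. learnable_keys_missing(self.model_visual, "Visualized_base_en_v1.5.pth")
# should return [] if the pre-trained weights leave no learnable layer uninitialized.
```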

  2. In the mm_encoder function in modeling_ds_cirr.py, prompt_embedding_output is split into the cls_token and the rest of the embedding, and the final embedding is built by concatenating [cls_token, img_token_emb, the rest]. However, why is prompt_attention_mask not split in the same way as prompt_embedding_output?

```python
cls_token = prompt_embedding_output[:, 0:1, :]
prompt_embedding_output = prompt_embedding_output[:, 1:]

prompt_img_embedding = torch.cat([cls_token, img_token_emb, prompt_embedding_output], dim=1)  # image-text sequence embedding

img_attention_mask = torch.ones(batch_size, img_token_len, device=device)
prom_img_attention_mask = torch.cat([img_attention_mask, prompt_attention_mask], dim=1)
```

JUNJIE99 commented 1 month ago

For the first question, this message is not significant and does not affect the training or evaluation results because RoPE does not contain any learnable parameters.
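
A toy illustration of why the missing buffers are harmless (this is just a sketch, not the actual EVA-CLIP code; the dimensions and base frequency are arbitrary):

```python
import torch
import torch.nn as nn

class ToyRope(nn.Module):
    """Minimal stand-in for a rotary embedding module: the cos/sin tables are
    computed deterministically from fixed frequencies, so they carry no
    learned information."""
    def __init__(self, dim: int = 64, seq_len: int = 196):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(seq_len).float()
        freqs = torch.outer(t, inv_freq)
        # Registered as buffers: they show up in missing_keys if the checkpoint
        # omits them, but load_state_dict(strict=False) simply keeps the values
        # recomputed above, which are identical by construction.
        self.register_buffer("freqs_cos", freqs.cos())
        self.register_buffer("freqs_sin", freqs.sin())

rope = ToyRope()
result = rope.load_state_dict({}, strict=False)  # checkpoint without the tables
print(result.missing_keys)      # ['freqs_cos', 'freqs_sin'] -- warning only, nothing is lost
print(list(rope.parameters()))  # [] -- no learnable parameters to initialize
```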

For the second question, you could split prom_img_attention_mask to mirror the split of prompt_embedding_output. However, the current method produces equivalent results: the mask only needs to identify the padding tokens at the end, padding only occurs in the text portion, and the number of image tokens is always constant.
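
Here is a small check of that equivalence (a sketch with made-up shapes, not code from the repository):

```python
import torch

batch_size, img_token_len = 2, 4
device = "cpu"

# Text attention mask including the CLS position: CLS is always 1 and
# padding tokens only appear at the end of the text.
prompt_attention_mask = torch.tensor([[1., 1., 1., 1., 0., 0.],
                                      [1., 1., 1., 0., 0., 0.]], device=device)

# Current construction in mm_encoder: image mask prepended to the full text mask.
img_attention_mask = torch.ones(batch_size, img_token_len, device=device)
current = torch.cat([img_attention_mask, prompt_attention_mask], dim=1)

# "Strict" construction mirroring the embedding order [CLS, image tokens, rest of text].
cls_mask = prompt_attention_mask[:, 0:1]
rest_mask = prompt_attention_mask[:, 1:]
strict = torch.cat([cls_mask, img_attention_mask, rest_mask], dim=1)

print(torch.equal(current, strict))  # True, since the CLS mask is always 1
```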