huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Out-of-Index Error when training with `Qwen2VLFlashAttention2` #33302

Closed: heli-qi closed this issue 3 months ago

heli-qi commented 3 months ago

System Info

Who can help?

@ArthurZucker @amyeroberts

Information

Tasks

Reproduction

Hi, I'm fine-tuning the newly released Qwen2VLForConditionalGeneration model with LoRA. I'm building the model with:

import torch
from transformers import Qwen2VLForConditionalGeneration

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct", attn_implementation="flash_attention_2", torch_dtype=torch.float16
)
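
For context, my LoRA setup looks roughly like the sketch below; the rank, alpha, and target_modules are illustrative placeholders, not an exact record of my configuration:

from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=16,  # placeholder rank
    lora_alpha=32,  # placeholder scaling
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # assumed attention projections
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()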

I found that attn_implementation="flash_attention_2" activates Qwen2VLFlashAttention2, which throws an out-of-index error during training. When I switch to attn_implementation="sdpa", the error does not occur and training runs smoothly.

After some debugging, I traced the problem to this line, where rotary_seq_len does not properly reflect the length of the input sequence but is instead the actual length minus 1. I changed this line to rotary_seq_len = cache_position[-1] + 1 in my local transformers installation, and training with flash_attention_2 now runs smoothly.
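
To illustrate the off-by-one, here is a minimal sketch, assuming cache_position is the usual torch.arange(seq_len) built for a training forward pass with no past key/value cache:

import torch

seq_len = 8
cache_position = torch.arange(seq_len)  # assumed: positions 0 .. seq_len - 1, no past cache

# The last position alone under-reports the sequence length by one...
assert cache_position[-1].item() == seq_len - 1

# ...while adding 1, as in the workaround above, recovers the true length.
assert (cache_position[-1] + 1).item() == seq_len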

My input batch to the model is as follows:

batch
    input_ids: Tensor (B, seq_len)
    attention_mask: Tensor (B, seq_len)
    labels: Tensor (B, seq_len)
    pixel_values: Tensor (B, res_h, res_w)  # res_h and res_w are the shape of the image after processor()
    image_grid_thw: Tensor (B, 3)
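
For reference, here is a minimal sketch of how a batch with these fields can be produced; the dummy image, the vision placeholder tokens in the prompt, and the naive labels are assumptions for illustration, not my actual preprocessing:

from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

image = Image.new("RGB", (448, 448), color="white")  # dummy image (assumption)
text = "<|vision_start|><|image_pad|><|vision_end|>Describe the image."  # assumed prompt format

batch = processor(text=[text], images=[image], padding=True, return_tensors="pt")
batch["labels"] = batch["input_ids"].clone()  # naive labels, just for a smoke test

print({k: tuple(v.shape) for k, v in batch.items()})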

I believe my input batch to the model has the correct shape, so I'm wondering whether my small workaround is the right fix for this problem. I would really appreciate it if you could point me to a better solution.

Expected behavior

As described in the Reproduction section: training with attn_implementation="flash_attention_2" should run without the out-of-index error, just as it does with attn_implementation="sdpa". Thanks for your patience with my issue.

zucchini-nlp commented 3 months ago

Will be fixed by https://github.com/huggingface/transformers/pull/33161 🤗