huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

FlashAttention2 issue with Mistral/Mixtral related to max length and RotaryEmbedding #31228

Closed. psinger closed this 3 months ago.

psinger commented 5 months ago

I have been seeing very weird behavior when training and running Mistral or Mixtral with samples that are exactly max_position_embeddings tokens long. The strange behavior manifested itself in completely broken outputs that, interestingly, resolved themselves after reloading the model and running shorter samples through it.

So the following combination always broke: a model with max_position_embeddings=8192, FA2 enabled, and samples of exactly max_length=8192. The issue was resolved by either disabling FA2 or using samples with max_length=8191.
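A minimal repro sketch along these lines (illustrative only: the checkpoint name is an assumption, and the original setup trained the model rather than just running a forward pass):

```python
import torch
from transformers import AutoModelForCausalLM

# Hypothetical checkpoint; the report assumes a config with max_position_embeddings=8192.
model_id = "mistralai/Mixtral-8x7B-v0.1"

for impl in ("flash_attention_2", "sdpa"):
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        attn_implementation=impl,
        device_map="auto",
    )
    # A sample that is exactly max_position_embeddings tokens long triggers the issue;
    # max_position_embeddings - 1 does not.
    seq_len = model.config.max_position_embeddings
    input_ids = torch.randint(0, model.config.vocab_size, (1, seq_len), device=model.device)
    with torch.no_grad():
        logits = model(input_ids).logits
    print(impl, logits.float().abs().mean().item())
```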

After a lot of debugging, I figured out that this issue only happens with Flash Attention 2 and not with SDPA or vanilla attention.

I am suspecting that this issue stems from the following line: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py#L447

```python
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
```

If we have a batch with a sequence length of, say, 8192, which can be the same as max_position_embeddings, then kv_seq_len is 8192 and is the maximum here. We then add 1, which yields 8193, and call rotary_emb with that value.

There, we then call https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py#L214 and thus re-initialize the cache with a sequence length longer than the supported maximum.
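For context, the rotary embedding on that line follows roughly this pattern (a paraphrased, self-contained sketch, not verbatim transformers code), which is where the cos/sin cache gets regrown to 8193 positions:

```python
import torch

class RotaryEmbeddingSketch(torch.nn.Module):
    """Paraphrased sketch of the pre-refactor rotary embedding (not verbatim)."""

    def __init__(self, dim, max_position_embeddings=8192, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._set_cos_sin_cache(max_position_embeddings)

    def _set_cos_sin_cache(self, seq_len):
        # Build cos/sin tables for positions [0, seq_len).
        self.max_seq_len_cached = seq_len
        t = torch.arange(seq_len, dtype=torch.float32)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def forward(self, x, seq_len=None):
        # The problematic path: if seq_len comes in as 8193 while 8192 was cached,
        # the cache is regrown beyond max_position_embeddings.
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len)
        return self.cos_cached[:seq_len].to(x.dtype), self.sin_cached[:seq_len].to(x.dtype)
```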

I think it can already be solved by changing it to `rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item() + 1)`.
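To make the off-by-one concrete, a standalone snippet with the hypothetical values from the example above:

```python
import torch

# Sample length equals max_position_embeddings in the failing case.
kv_seq_len = 8192
position_ids = torch.arange(kv_seq_len).unsqueeze(0)  # positions 0 .. 8191

# Current code: take the max first, then add 1 -> 8193, one past max_position_embeddings.
current = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
# Proposed fix: add 1 only to the 0-indexed last position -> max(8192, 8192) = 8192.
proposed = max(kv_seq_len, position_ids[:, -1].max().item() + 1)

print(current, proposed)  # 8193 8192
```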

I actually noticed that this code was changed very recently for Mistral so that it no longer takes the max length and resets the cache: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L108

This was done in PR https://github.com/huggingface/transformers/pull/30642

I think this might have been just a side effect, and it does not fix the Mixtral behavior.

amyeroberts commented 5 months ago

cc @ArthurZucker @younesbelkada

ArthurZucker commented 4 months ago

I think this is heavily related to #29496 and should simply need an update to the RoPE precision! Would you like to open a PR to update this and see if that solves your issue?
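(For readers following along: the RoPE precision updates referenced here generally amount to computing inv_freq and the rotation angles in float32 and only casting the final cos/sin tables back to the model dtype. A rough sketch of that pattern, not the exact transformers implementation:)

```python
import torch

def build_rope_tables(dim, seq_len, base=10000.0, out_dtype=torch.bfloat16):
    # Do the frequency and angle math in float32 to avoid precision loss at large
    # position indices; cast only the final cos/sin tables to the model dtype.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    t = torch.arange(seq_len, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos().to(out_dtype), emb.sin().to(out_dtype)
```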

github-actions[bot] commented 4 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.