cc @ArthurZucker @younesbelkada
I think this is heavily related to #29496, and should simply need an update in the ROPE precision! Would you like to open a PR to update this and see if that solves your issue?
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
I have been seeing very weird behavior when training and running Mistral or Mixtral with samples that are exactly the length of `max_position_embeddings`. The strange behavior manifested itself in completely broken outputs that, interestingly, resolved themselves after reloading the model and running shorter samples through it.

So the following combination always broke: a model with `max_position_embeddings=8192`, using FA2, and using some samples with `max_length=8192`. It was resolved by either disabling FA2 or by actually using samples with `max_length=8191`.

After a lot of debugging, I figured out that this issue only happens with Flash Attention 2 and not with SDPA or vanilla attention.
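For reference, here is a minimal sketch of the setup that breaks for me; the checkpoint name, lengths, and batch contents are placeholders, so adjust them to your own config:

```python
# Minimal repro sketch (checkpoint name and shapes are placeholders).
import torch
from transformers import AutoModelForCausalLM

model_id = "your-mistral-or-mixtral-checkpoint"  # with max_position_embeddings=8192
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # issue only reproduces with FA2
    device_map="auto",
)

max_len = model.config.max_position_embeddings  # 8192 in my case
# Batch whose sequence length is exactly max_position_embeddings
input_ids = torch.randint(0, model.config.vocab_size, (1, max_len), device=model.device)
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    out = model(input_ids=input_ids, attention_mask=attention_mask)
# With seq_len == max_len and FA2 the outputs break; with seq_len == max_len - 1,
# or with attn_implementation="sdpa", everything is fine.
```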
I suspect that this issue stems from the following line: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py#L447
If we have a batch with a sequence length of, let's say, `8192`, which could be the same as `max_position_embeddings`, then `kv_seq_len` will be `8192`, which is the maximum here. But then we add `1`, which leads to `8193`, and we call `rotary_emb` with it.

There, we then call https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py#L214 and thus re-initialize the cache with a longer-than-supported max sequence length.
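To make the off-by-one concrete, here is a simplified sketch of the arithmetic described above (not the actual library code, just the lengths involved):

```python
# Simplified sketch of the FA2 code path described above (not the library code itself).
max_position_embeddings = 8192   # model limit
seq_len = 8192                   # batch uses the full context
past_kv_len = 0                  # no cache during training

kv_seq_len = seq_len + past_kv_len   # 8192
rotary_seq_len = kv_seq_len + 1      # 8193 -> one more than the model supports

# The rotary embedding rebuilds its cos/sin cache whenever the requested length
# exceeds the cached one, so here it gets re-initialized for 8193 positions.
print(rotary_seq_len > max_position_embeddings)  # True
```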
I think this can already be solved by changing it to:

`rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item() + 1)`
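If I read the code correctly, for an unpadded full-length batch the position ids run from `0` to `8191`, so `position_ids[:, -1].max().item() + 1` evaluates to `8192`, which equals `kv_seq_len`. The rotary cache would then no longer be rebuilt for one position more than the model supports.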
I actually noticed that this code was changed very recently for Mistral so that it no longer takes the max length and resets the cache: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L108
This was done in PR https://github.com/huggingface/transformers/pull/30642
I think this might have just been a side effect, and it does not fix the Mixtral behavior.