huggingface / transformers


Precision issues in Mistral rotary embeddings #29496

avnermay closed this issue 5 months ago

avnermay commented 8 months ago

https://github.com/huggingface/transformers/blob/965cf677695dd363285831afca8cf479cf0c600c/src/transformers/models/mistral/modeling_mistral.py#L120-L121
https://github.com/huggingface/transformers/blob/965cf677695dd363285831afca8cf479cf0c600c/src/transformers/models/mistral/modeling_mistral.py#L377

If, during mixed-precision training of a Mistral model (e.g., bf16 with the HF Trainer), you pass an input equal to (or longer than) the model's maximum sequence length, it regenerates the sin_cached and cos_cached tensors, and these are incorrect due to precision issues. In particular, the inv_freq tensor is in bf16 at that point, which is what causes the problem. This leads to large degradation in model quality.

Other models and code bases deal with this by forcing the inv_freq tensor to be float32, which would be good to do here as well. It would also be worth double-checking the other models to make sure this precision problem does not affect them. See, for example:
https://github.com/huggingface/transformers/blob/965cf677695dd363285831afca8cf479cf0c600c/src/transformers/models/llama/modeling_llama.py#L136-L147
https://github.com/Dao-AILab/flash-attention/blob/6c9e60de566800538fedad2ad5e6b7b55ca7f0c5/flash_attn/layers/rotary.py#L383-L392
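
For reference, here is a minimal sketch of that pattern, assuming a bare-bones rotary embedding module (the class name and constructor arguments are illustrative, not the actual transformers code):

import torch
from torch import nn

class RotaryEmbeddingSketch(nn.Module):
    # Illustrative only: force inv_freq into float32 at construction time.
    def __init__(self, dim, base=10000.0):
        super().__init__()
        # int64 arange followed by .float() yields float32 regardless of the default dtype.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        # Non-persistent buffer: moves with .to(device) but is not saved in the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)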

ArthurZucker commented 8 months ago

Do you want to open a PR to propagate the changes we made to Llama and Gemma?

ArthurZucker commented 8 months ago

cc @gante

danielhanchen commented 8 months ago

@avnermay I'm not too certain, but I think inv_freq will always be calculated in float32. For example, for Gemma:

self.inv_freq = 1.0 / (
    self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
)

And for Llama:

inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))

The downcast only applies to matrix multiplications and explicit downcasts, similar to what I found Keras doing.

I haven't run the code to confirm, but it would be great if you could print the dtype during a finetuning run to confirm whether inv_freq is actually bfloat16.

gante commented 8 months ago

@danielhanchen the inv_freq permanent buffer can be cast by a model-wide .to call, e.g.

from transformers import AutoModelForCausalLM
import torch

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = model.to(device="cuda", dtype=torch.bfloat16)
print(model.model.layers[0].self_attn.rotary_emb.inv_freq.dtype)
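
(With the model cast to bfloat16 as above, this should print torch.bfloat16, i.e. the buffer follows the model-wide cast.)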

On Llama and Gemma that's not a problem, since we've recently updated the code to cast inv_freq to float() before it is used to compute sin and cos (e.g. here). However, other RoPE models like Mistral have yet to receive the same treatment.
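
For context, here is a rough sketch of that upcast, i.e. building cos/sin in float32 even when the inv_freq buffer itself has been downcast by a model-wide .to(torch.bfloat16) (the function name is illustrative, not the actual modeling code):

import torch

def compute_cos_sin(inv_freq, position_ids, out_dtype):
    # Upcast: even if a model-wide cast downcast the inv_freq buffer to bf16,
    # the rotary angles and their cos/sin are computed in float32.
    freqs = position_ids.float()[..., None] * inv_freq.float()
    emb = torch.cat((freqs, freqs), dim=-1)
    # Cast back to the activation dtype only at the very end.
    return emb.cos().to(out_dtype), emb.sin().to(out_dtype)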

We'll gladly take PRs to fix it ;) We will be touching the other RoPE models soon anyway, to migrate them to a Llama-like structure (which, contrary to the other models, is compatible with torch.compile).

danielhanchen commented 8 months ago

@gante Whoops sorry just saw this - apologies!

Oh, fair points! Hmm, is there some sort of lock-in mechanism to prevent the conversion from occurring? Maybe some overriding mechanism, i.e. overriding tensor.to itself?

avnermay commented 8 months ago

Why not use the approach taken by the other models, which forces inv_freq to be float32? The key is to avoid cases where cos and sin are recomputed from a low-precision inv_freq tensor. This happens, for example, during mixed-precision training, because inv_freq is automatically downcast to bfloat16 in that case.

gante commented 8 months ago

@danielhanchen the only solution is to explicitly upcast šŸ˜¬ Some frameworks like deepspeed can hijack tensor creation and force tensors to be initialized in a certain type (which has also caused issues with RoPE).

@avnermay that is the solution. The change is simple, but we are working on other overlapping problems -- bear with us šŸ¤—

avnermay commented 7 months ago

Just commenting on this so that it is not marked as stale. Thanks!

ArthurZucker commented 6 months ago

#30642 will fix this! šŸ¤—

github-actions[bot] commented 5 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.