Closed avnermay closed 5 months ago
Do you want to open a PR to propagate the changes we made to Llama and Gemma?
cc @gante
@avnermay I'm not too certain, but I think inv_freq will always be calculated in float32. For example, Gemma:
self.inv_freq = 1.0 / (
self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim))
And for Llama:
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
The downcast only applies to matrix multiplications and to explicit downcasts, like the one I found in Keras.
I haven't run the code to confirm, but it would be great if you could print the dtype during a finetuning run to confirm that inv_freq is actually bfloat16.
@danielhanchen the inv_freq permanent buffer can be cast when the whole model is cast with .to, e.g.
from transformers import AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = model.to(device="cuda", dtype=torch.bfloat16)
print(model.model.layers[0].self_attn.rotary_emb.inv_freq.dtype)
On Llama and Gemma that's no problem, since we've recently updated the code to cast inv_freq to float() before it is used to compute sin and cos (e.g. here). However, other RoPE models like Mistral have yet to receive the same treatment.
We'll gladly take PRs to fix it ;) We will be touching the other RoPE models soon anyway, to migrate them to a Llama-like structure (which, unlike the structure of the other models, is compatible with torch.compile).
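To make the upcast described above concrete, here is a minimal sketch. It is not the transformers implementation: `RotarySketch` and its `out_dtype` argument are hypothetical names made up for illustration. It only shows the pattern of casting inv_freq to float32 right before cos and sin are computed, and downcasting the results afterwards.

```python
import torch
from torch import nn


class RotarySketch(nn.Module):
    """Hypothetical minimal RoPE module, for illustration only."""

    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        # Buffers follow module.to(dtype=...), so this one can be downcast later.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, position_ids: torch.Tensor, out_dtype: torch.dtype):
        # Explicit upcast: the position * inv_freq product is done in float32.
        # Note: this cannot restore bits that were already rounded away if the
        # buffer itself was cast to a lower precision.
        inv_freq = self.inv_freq.float()
        freqs = position_ids.float()[:, None] * inv_freq[None, :]
        emb = torch.cat((freqs, freqs), dim=-1)
        # Only the final cos/sin values are downcast.
        return emb.cos().to(out_dtype), emb.sin().to(out_dtype)


rope = RotarySketch(dim=128).to(dtype=torch.bfloat16)
print(rope.inv_freq.dtype)  # shows what .to() did to the buffer
cos, sin = rope(torch.arange(4096), out_dtype=torch.bfloat16)
print(cos.dtype, cos.shape)
```

The upcast inside forward is what keeps the angle computation out of low precision, whatever dtype the rest of the model runs in.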
@gante Whoops sorry just saw this - apologies!
Oh fair points on this! Hmm, is there some sort of lock-in mechanism to prevent the conversion from occurring? Maybe some sort of overriding mechanism, i.e. overriding tensor.to itself?
Why not use the approach taken by the other models, which force inv_freq to be float32? The key is avoiding cases where cos and sin are recomputed using a low-precision inv_freq tensor. This occurs (for example) during mixed precision training, because inv_freq is automatically downcast to bfloat16 in that case.
@danielhanchen the only solution is to explicitly upcast. Some frameworks like deepspeed can explicitly hijack tensor creation and force tensors to be initialized in a certain type (which has also caused issues with RoPE).
@avnermay that is the solution. The change is simple, but we are working on other overlapping problems -- bear with us.
Just commenting on this so that it is not marked as stale. Thanks!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
https://github.com/huggingface/transformers/blob/965cf677695dd363285831afca8cf479cf0c600c/src/transformers/models/mistral/modeling_mistral.py#L120-L121
https://github.com/huggingface/transformers/blob/965cf677695dd363285831afca8cf479cf0c600c/src/transformers/models/mistral/modeling_mistral.py#L377
If during mixed precision training (e.g., bf16 with the HF Trainer) of a Mistral model you pass an input whose length is equal to (or greater than) the model's maximum sequence length, the model will generate new sin_cached and cos_cached tensors, which will be incorrect due to precision issues. In particular, the inv_freq tensor will be in bf16, and this is what causes the problem. The result is a large degradation in model quality.
Other models and code bases deal with this by forcing the inv_freq tensor to be float32, which I believe is what should be done here as well. It would also be a good idea to double check other models to make sure this precision problem does not affect them.
https://github.com/huggingface/transformers/blob/965cf677695dd363285831afca8cf479cf0c600c/src/transformers/models/llama/modeling_llama.py#L136-L147
https://github.com/Dao-AILab/flash-attention/blob/6c9e60de566800538fedad2ad5e6b7b55ca7f0c5/flash_attn/layers/rotary.py#L383-L392
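To give a feel for the size of the error, here is a small standard-library sketch. It makes two assumptions: bfloat16 is emulated by rounding float32 bit patterns (`to_bf16` is a hypothetical helper, not a library call), and only the rounding of inv_freq is modelled, while the rest of the computation stays in full precision. A real bf16 run also rounds the positions and the products, so this isolates just one source of error.

```python
import math
import struct


def to_bf16(x: float) -> float:
    """Round a finite float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bfloat16 keeps the top 16 bits of the float32 pattern.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = ((bits + rounding_bias) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


def inv_freq(dim: int, base: float = 10000.0) -> list:
    """Same formula as in the thread: 1 / base ** (i / dim) for even i."""
    return [1.0 / (base ** (i / dim)) for i in range(0, dim, 2)]


def max_cos_error(position: int, dim: int = 128) -> float:
    """Largest |cos| difference caused by rounding inv_freq to bfloat16."""
    worst = 0.0
    for f in inv_freq(dim):
        exact = math.cos(position * f)
        rounded = math.cos(position * to_bf16(f))
        worst = max(worst, abs(exact - rounded))
    return worst


for pos in (1, 64, 4096):
    print(pos, max_cos_error(pos))
```

The rounding error in each frequency is multiplied by the position, so it is negligible for short inputs and grows to several radians of phase error near the end of a long sequence.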