huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

llama 2 weights from fb (in bfloat16) are perhaps accidentally cast to float16 in conversion script? #25446

Closed: jmhessel closed this issue 1 year ago

jmhessel commented 1 year ago

Who can help?

@ArthurZucker

Reproduction

Hi there!

Credit to @dirkgr and @jacob-morrison for finding this LoC (line of code)!

The Facebook -> Hugging Face conversion script for llama/llama2 appears to cast weights to float16. The llama1 weights were distributed in fp16:

>>> import torch
>>> loaded_llama1 = torch.load("llama1/7B/consolidated.00.pth", map_location="cuda:0")
>>> loaded_llama1['layers.4.feed_forward.w2.weight'].dtype
torch.float16

whereas the llama2 weights appear to be in bfloat16:

>>> loaded_llama2 = torch.load("Llama-2-7b/consolidated.00.pth", map_location="cuda:0")
>>> loaded_llama2['layers.4.feed_forward.w2.weight'].dtype
torch.bfloat16

The casting differences are small in both absolute and relative terms (here's one randomly chosen weight):

from transformers import AutoModelForCausalLM
import torch

facebook_model = torch.load("Llama-2-7b/consolidated.00.pth", map_location="cuda:0")
# No torch_dtype passed, so the Hugging Face weights load in torch's default float32
huggingface_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
from_huggingface = huggingface_model.get_parameter('model.layers.19.self_attn.v_proj.weight').cuda()
from_facebook = facebook_model['layers.19.attention.wv.weight']  # stored as bfloat16
...
# Mean relative difference, in percent
print(torch.mean((torch.abs(from_facebook - from_huggingface) / torch.abs(from_facebook))*100))
# > tensor(0.0006, device='cuda:0', dtype=torch.float32, grad_fn=<MeanBackward0>)
# Maximum absolute difference
print((from_facebook - from_huggingface).abs().max())
# > tensor(2.9802e-08, device='cuda:0', dtype=torch.float32, grad_fn=<MaxBackward1>)

But, in theory, this could lead to a small performance degradation (e.g., https://github.com/facebookresearch/llama/issues/634).
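One plausible reading of the numbers above (an assumption, not something verified in this thread): every bfloat16 value that lies inside float16's normal range is exactly representable in float16, because float16 has more mantissa bits (10 vs 7). Only very small weights, below float16's smallest normal value (2**-14, about 6.1e-5), fall into float16's subnormal range, where representable values are 2**-24 apart. Rounding there moves a value by at most 2**-25, about 2.98e-8, which matches the maximum difference reported above. The sketch below uses a hypothetical helper `cast_error` and hand-picked values, not real Llama 2 weights:

```python
import torch

def cast_error(x: torch.Tensor) -> torch.Tensor:
    """Absolute error from casting bfloat16 values to float16, measured in float32."""
    assert x.dtype == torch.bfloat16
    return (x.float() - x.to(torch.float16).float()).abs()

# Weights of "normal" size: float16 represents these bfloat16 values exactly.
normal = torch.tensor([0.0123, -0.5, 1.75, 3e-3], dtype=torch.bfloat16)

# Tiny weights below float16's smallest normal value become float16 subnormals,
# which are spaced 2**-24 apart, so the rounding error is at most 2**-25.
tiny = torch.tensor([1e-6, -3.3e-7, 7e-8], dtype=torch.bfloat16)

print(torch.finfo(torch.float16).tiny)   # smallest normal float16 value
print(cast_error(normal).max().item())   # 0.0
print(cast_error(tiny).max().item(), 2**-25)
```

If this reading is right, the cast only perturbs the smallest-magnitude weights, which is consistent with the tiny mean relative difference shown above.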

Expected behavior

I think llama2 should probably be saved in bfloat16 rather than cast to float16 and then saved in huggingface format.
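A minimal sketch of the requested behavior, assuming the original checkpoint is a flat state dict of tensors as returned by `torch.load` above: detect the dtype the checkpoint was distributed in and keep it, instead of hard-coding float16. `checkpoint_dtype` and `convert_preserving_dtype` are hypothetical helpers for illustration, not functions from the actual conversion script, and the key renaming a real Facebook -> Hugging Face conversion performs is omitted.

```python
import torch

def checkpoint_dtype(state_dict: dict) -> torch.dtype:
    """Return the single floating-point dtype used by a state dict (hypothetical helper)."""
    dtypes = {t.dtype for t in state_dict.values() if t.is_floating_point()}
    if len(dtypes) != 1:
        raise ValueError(f"expected exactly one floating-point dtype, found {dtypes}")
    return dtypes.pop()

def convert_preserving_dtype(state_dict: dict, target_dtype=None) -> dict:
    """Copy tensors, casting only when a target dtype is explicitly requested.

    Key renaming (e.g. 'layers.19.attention.wv.weight' ->
    'model.layers.19.self_attn.v_proj.weight') is left out of this sketch.
    """
    dtype = target_dtype if target_dtype is not None else checkpoint_dtype(state_dict)
    return {k: v.to(dtype) if v.is_floating_point() else v for k, v in state_dict.items()}

# Stand-in for torch.load("Llama-2-7b/consolidated.00.pth"), which holds bfloat16 tensors.
fake_llama2 = {
    "layers.4.feed_forward.w2.weight": torch.randn(4, 4).to(torch.bfloat16),
    "layers.19.attention.wv.weight": torch.randn(4, 4).to(torch.bfloat16),
}
converted = convert_preserving_dtype(fake_llama2)
print({k: v.dtype for k, v in converted.items()})
```

Making the cast opt-in keeps the original precision by default while still allowing an explicit float16 export for inference.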

ArthurZucker commented 1 year ago

Hey! The weights were pushed as float16 because that is what is used for inference. We are going to add a line mentioning that training was done in bfloat16, but there should not be performance issues when training now. There is an issue with training in float16, as reported in #25065, which is expected. Note that by default, if you use LlamaXXXX, the dtype will be torch's default float32.

jmhessel commented 1 year ago

Hey @ArthurZucker! Thanks for the reply :-) When you say "The weights were pushed as float16 as this is what is used for inference", is this what Meta used in the llama2 paper for their results? I guess I am wondering why not also use bfloat16 for inference, particularly because of the potential issue when fine-tuning in float16. I could probably just do the conversion myself and host the weights on my Hugging Face account, but I'm wondering about the rationale, if possible.

ArthurZucker commented 1 year ago

If you look at this or just try running a model, you'll see that it is in fp16. The main reason is that it's faster and should not really induce a performance loss. But training was done in bf16.

jmhessel commented 1 year ago

Aha! Gotcha :-) If the official implementation does it this way, it seems good to have it this way in Hugging Face too. Maybe I'll do a bfloat16 conversion on my own and run some experiments, but it's probably not a big deal either way. Thanks!

jmhessel commented 1 year ago

FYI, for future readers, something related to this bfloat conversion was made: https://github.com/huggingface/transformers/commit/015f8e110d270a0ad42de4ae5b98198d69eb1964#diff-110a445233a8b15a0875998eeaf75cb8607b38a5daa736291dd058766879bbddL259-R273

It isn't clear to me what this change actually does, but it looks like the codellama weights are in bfloat16: https://huggingface.co/codellama/CodeLlama-7b-hf

Not sure if this was intended, but it might be worth trying the conversion for the original models in bfloat16 :-) (which I might try).

ArthurZucker commented 1 year ago

Once and for all: the dtype of the checkpoints on the Hub is only used if you set torch_dtype = "auto" when you initialize the model. Otherwise, the torch_dtype you pass will be used to cast the checkpoints from the initialization type (torch's default float32) to that torch_dtype (only when you are using the auto API). The reason we used torch_dtype = torch.float16 is that it is the inference dtype, and thus, for the most common usages where you just want something to work out of the box, it is the type that should be used.
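The rule described above can be modeled in a few lines. This is a hypothetical, simplified sketch of the described semantics, not the actual transformers loading code: `resolve_load_dtype` is an illustrative name, `config_torch_dtype` stands for the `torch_dtype` string recorded in a checkpoint's config.json (e.g. "bfloat16" for codellama/CodeLlama-7b-hf per the comment above), and newer transformers releases may behave differently.

```python
import torch

def resolve_load_dtype(torch_dtype, config_torch_dtype: str) -> torch.dtype:
    """Model of the dtype rule described above (hypothetical, simplified).

    - torch_dtype=None:   use torch's default dtype (normally float32).
    - torch_dtype="auto": use the dtype recorded with the checkpoint.
    - an explicit dtype:  cast the checkpoint to that dtype.
    """
    if torch_dtype is None:
        return torch.get_default_dtype()
    if isinstance(torch_dtype, str):
        if torch_dtype != "auto":
            raise ValueError(f"unsupported torch_dtype string: {torch_dtype!r}")
        return getattr(torch, config_torch_dtype)  # e.g. "bfloat16" -> torch.bfloat16
    return torch_dtype

for requested in (None, "auto", torch.float16):
    print(requested, "->", resolve_load_dtype(requested, "bfloat16"))
```

With the real library, the corresponding calls in the versions discussed here would be `AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype="auto")` or passing an explicit dtype such as `torch_dtype=torch.bfloat16`.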

ZhaofengWu commented 11 months ago

Following up on this: what is the recommended dtype for llama2 inference? I assumed it's torch.float16, given this thread, and I've always worked under the assumption that the dtype in config.json is the recommendation. However, (1) I saw NaN issues running inference with torch.float16, which went away after switching to torch.bfloat16; (2) the config for codellama specifies torch.bfloat16, as Jack pointed out above.
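One plausible explanation for NaNs that appear in float16 but not in bfloat16 (an assumption; the thread does not diagnose it): float16's largest finite value is 65504, while bfloat16 shares float32's exponent range (maximum around 3.4e38). An intermediate activation above 65504 becomes inf in float16, and a later operation such as inf - inf produces NaN. The toy computation below, with a hypothetical helper `square_then_center` and made-up values, illustrates the mechanism:

```python
import torch

def square_then_center(dtype: torch.dtype) -> torch.Tensor:
    """Toy computation: square a large activation, then subtract it from itself."""
    x = torch.tensor([300.0], dtype=dtype)
    y = x * x      # 90000 exceeds float16's max of 65504, so float16 gives inf
    return y - y   # inf - inf = nan in float16; a finite value minus itself in bfloat16

print(torch.finfo(torch.float16).max)  # 65504.0
print(square_then_center(torch.float16))
print(square_then_center(torch.bfloat16))
```

bfloat16 trades mantissa precision for range, which is why it tends to avoid this kind of overflow.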

fxmarty commented 4 months ago

@ArthurZucker Meta officially recommends bf16 though https://github.com/meta-llama/llama3?tab=readme-ov-file#access-to-hugging-face.

It is not obvious to me that casting bf16 -> fp16 -> bf16 does not lead to degraded perf.
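A small sketch of where a bf16 -> fp16 -> bf16 round trip can lose information, assuming PyTorch's default round-to-nearest casts: values inside float16's normal range come back unchanged, tiny values are rounded onto float16's subnormal grid or flushed to zero, and values above 65504 become inf. `roundtrip` is a hypothetical helper and the inputs are illustrative, not actual Llama weights.

```python
import torch

def roundtrip(x: torch.Tensor) -> torch.Tensor:
    """Cast bfloat16 -> float16 -> bfloat16, as when fp16 weights are reloaded in bf16."""
    return x.to(torch.float16).to(torch.bfloat16)

x = torch.tensor([0.25, -1.5e-2, 1e-6, 1e-9, 1e5], dtype=torch.bfloat16)
y = roundtrip(x)
changed = x != y
for before, after, diff in zip(x.tolist(), y.tolist(), changed.tolist()):
    print(f"{before:>12.4g} -> {after:>12.4g}  changed={diff}")
```

Here the first two values survive, 1e-6 is shifted, 1e-9 becomes 0, and 1e5 becomes inf.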

ArthurZucker commented 4 months ago

It does lead to degraded performance. We pushed bfloat16 weights, and the script respects that now.