huggingface / alignment-handbook

Robust recipes to align language models with human and AI preferences
https://huggingface.co/HuggingFaceH4
Apache License 2.0

Question about torch_dtype when running run_orpo.py #174

Closed sylee96 closed 2 months ago

sylee96 commented 3 months ago

I have been using run_orpo.py with my own data successfully. However, while using it, a question came up.

When I look at the code for run_orpo.py, I see that there is code that matches torch_dtype to the dtype of the pretrained model. However, when I actually train and save the model, the dtype is changed to fp32 even if the pretrained model's dtype was bf16. Why does this happen?
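(For context, a minimal sketch of the pattern in question, assuming the recipe config carries a torch_dtype string such as "bfloat16" that the training script resolves and passes to from_pretrained; the variable and model names below are illustrative, not the exact run_orpo.py code.)

import torch
from transformers import AutoModelForCausalLM

# Illustrative only: resolve a torch_dtype string from the recipe config into a
# torch dtype and use it when loading the base model.
pretrained_model_name_or_path = "meta-llama/Meta-Llama-3-8B"  # e.g. one of the models mentioned later in the thread
torch_dtype_str = "bfloat16"  # value assumed to come from the YAML recipe
torch_dtype = getattr(torch, torch_dtype_str)  # -> torch.bfloat16

model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path,
    torch_dtype=torch_dtype,
)
print(model.dtype)  # torch.bfloat16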

alvarobartt commented 2 months ago

Hi here! Not sure if that's related to https://github.com/huggingface/alignment-handbook/issues/175 at all, but feel free to upgrade the trl version and re-run as mentioned in that issue 🤗

Other than that, could you share the configuration you're using so that we can reproduce and debug that issue further? Thanks in advance!
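(Before re-running, it may also help to confirm which trl version is actually installed; a minimal check using only the standard library:)

from importlib.metadata import version

# Print the installed trl version before upgrading, to compare against the fix
# mentioned in issue #175.
print(version("trl"))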

sylee96 commented 2 months ago

Hi alvarobartt,

Here are the details of the environment and configuration I used:

run_orpo.py configuration: torch_dtype set to bf16

With this setup, the dtype of the model changes to fp32 when saving the model, even though it was set to bf16. Please let me know if you need any additional information.

Thanks!

alvarobartt commented 2 months ago

Thanks for that @sylee96. To make sure I understand the problem: training is indeed happening in bfloat16, but save_pretrained is storing the weights in float32 instead? How did you check that? Could you share the command you're running, e.g. python run_orpo.py ... or accelerate launch ... run_orpo.py ..., so we can try to reproduce on our end? Thanks again 🤗

sylee96 commented 2 months ago

Thanks for answering, @alvarobartt.

I run it with the following command:

ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch --config_file orpo/configs/fsdp.yaml orpo/run_orpo.py orpo/configs/config_full.yaml

When I checked the dtype of the gemma2, llama3, or qwen2 models before training, it was bfloat16. But when I checked the dtype after training and saving, it had changed to float32.

When I checked the model's dtype, I used these lines:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)
print(model.dtype)

alvarobartt commented 2 months ago

> Thanks for answering, @alvarobartt.

Anytime @sylee96!

> When I checked the model's dtype, I used these lines:
>
> from transformers import AutoModelForCausalLM
>
> model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)
> print(model.dtype)

You should load it as follows, i.e. specifying the torch_dtype to use when loading the model; otherwise torch.float32 is used by default.

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16)
print(model.dtype)

Hope that helps! 🤗
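(One caveat worth adding here: the dtype reported by from_pretrained without a torch_dtype argument reflects this float32 loading default, not necessarily what is stored on disk. A minimal sketch of two ways to check the saved checkpoint, assuming an unsharded model.safetensors file under an illustrative output_dir path:)

from safetensors import safe_open
from transformers import AutoModelForCausalLM

output_dir = "path/to/orpo/output_dir"  # illustrative path to the saved model

# Option 1: let transformers derive the dtype from the checkpoint instead of
# defaulting to float32.
model = AutoModelForCausalLM.from_pretrained(output_dir, torch_dtype="auto")
print(model.dtype)

# Option 2: inspect the dtype of a tensor as it is actually stored on disk
# (assumes a single, unsharded model.safetensors file).
with safe_open(f"{output_dir}/model.safetensors", framework="pt") as f:
    first_key = next(iter(f.keys()))
    print(first_key, f.get_tensor(first_key).dtype)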

sylee96 commented 2 months ago

Thanks for your help, @alvarobartt!

I will close this issue.