huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

SFTTrainer encounters error with OPT finetuning (int8 + LoRA) #1109

Closed: chenmoneygithub closed this issue 7 months ago

chenmoneygithub commented 9 months ago

Hi team,

I am encountering an error when fine-tuning an OPT model with SFTTrainer:

RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::Half != float

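It turned out to be caused by quantization: removing load_in_8bit=True from the model loading below makes training work:

from accelerate import Accelerator
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    load_in_8bit=True,  # removing this line makes the dtype error go away
    device_map={"": Accelerator().process_index},  # place the whole model on this process's GPU
)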

Please find the reproducible code in this Colab: link

I am trying to adapt the stack_llama example into a single Colab notebook that runs on a single GPU, so that people can understand the API better. While doing so, I ran into the problem above. Could anyone take a look? Thanks!

lvwerra commented 9 months ago

Maybe @younesbelkada knows more about this.

younesbelkada commented 8 months ago

Hi @chenmoneygithub, thanks for the issue! There are some known issues with bitsandbytes 8-bit and V100 GPUs. Can you try wrapping the training in a torch autocast context manager? Alternatively, I suggest switching to 4-bit, as it is much more stable for LoRA training. Simply change:

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
-   load_in_8bit=True,
+   load_in_4bit=True,
    device_map={"": Accelerator().process_index},
)

chenmoneygithub commented 8 months ago

@younesbelkada Thanks for the information! load_in_4bit works magically.

> Can you try to wrap the training with torch autocast context manager?

I tried it: it bypassed the mixed-precision crash, but the loss was 0 from the very start of training.

younesbelkada commented 8 months ago

Thanks @chenmoneygithub! Does training work normally without the autocast context manager?

chenmoneygithub commented 8 months ago

@younesbelkada Yes, using load_in_4bit without the autocast context manager worked well, thanks for your help!

BTW, when I used a LoRA + int8 model to train the reward model via RewardTrainer, I also hit a loss=0 problem. Is there any trick to training the reward model? I'm stuck here, since without a decent reward model I cannot move on to writing the RLHF part. I'm also curious what you think about DPO vs. PPO.

Lastly, thank you all so much for creating and maintaining the TRL package, it's a huge pain to do practical LLM training with vanilla PyTorch, and TRL is exactly what I'm looking for.

younesbelkada commented 8 months ago

Thanks a lot for your message @chenmoneygithub ! Glad that you like using TRL for your projects :D

> when I was using LoRA + int8 model to train the reward model via RewardTrainer, I hit loss=0 problem.

Not sure if this is related to int8, but I now strongly advise users to use 4-bit instead of 8-bit: the 8-bit layers were not designed for LoRA in the first place, whereas the 4-bit layers were designed specifically for fine-tuning. So if the issue disappears with 4-bit, it was probably caused by 8-bit.
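For context on the loss=0 symptom: TRL's RewardTrainer optimizes a pairwise objective, the mean of -log(sigmoid(r_chosen - r_rejected)) over preference pairs. The sketch below is a plain-Python re-implementation of that formula for illustration (not TRL's code), assuming one scalar reward per completion. It shows that a reward head scoring both completions equally yields ln 2 ≈ 0.693, so a loss of exactly 0 from the first step more likely points to a numerical or setup problem than to a model that has already learned the preferences.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def pairwise_reward_loss(rewards_chosen, rewards_rejected):
    """Mean of -log(sigmoid(r_chosen - r_rejected)) over a batch of preference pairs.

    Uses the identity -log(sigmoid(d)) == softplus(-d) to avoid overflow.
    """
    losses = [softplus(-(r_c - r_r)) for r_c, r_r in zip(rewards_chosen, rewards_rejected)]
    return sum(losses) / len(losses)


# Equal rewards for chosen and rejected, as from an uninformative head.
initial_loss = pairwise_reward_loss([0.0, 0.0], [0.0, 0.0])  # equals ln 2 ≈ 0.693
```

The loss only approaches 0 as the chosen reward exceeds the rejected one by a wide margin, and it grows above ln 2 when the model prefers the rejected completion.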

To sidestep the RM issue, I also strongly advise using DPO, since it removes the need for a reward model. You do, however, need a paired preference dataset, such as the UltraFeedback dataset. You can follow what was done for Zephyr, which benefits greatly from DPO; see the alignment handbook: https://github.com/huggingface/alignment-handbook
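To illustrate why DPO needs no separate reward model, here is a minimal plain-Python sketch of the standard (sigmoid) DPO loss as defined in the DPO paper; it is not TRL's implementation. It assumes you already have the summed log-probabilities of the chosen and rejected completions under the policy being trained and under a frozen reference model; beta=0.1 is an assumed default. The reward is implicit in the policy/reference log-ratio, so the only extra model is the reference copy.

```python
import math


def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """Sigmoid DPO loss for a single preference pair (assumed sequence-level log-probs)."""
    # Implicit rewards: how much more likely the policy makes each completion vs. the reference.
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    logits = beta * (chosen_logratio - rejected_logratio)
    # -log(sigmoid(logits)) computed stably as softplus(-logits)
    return max(-logits, 0.0) + math.log1p(math.exp(-abs(logits)))


# Policy identical to the reference: no preference learned yet, loss is ln 2.
start_loss = dpo_loss(-10.0, -10.0, -10.0, -10.0)
# Policy raised the chosen completion and lowered the rejected one relative to the reference.
improved_loss = dpo_loss(-8.0, -12.0, -10.0, -10.0)
```

Raising the chosen completion's log-probability relative to the reference, and lowering the rejected one's, pushes the loss below ln 2; beta controls how sharply this margin is rewarded.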

@kashif from the team is also working on a KTO trainer, an extension of DPO that is easier to use: instead of requiring a paired dataset of accepted and rejected completions, KTO only needs a 'selected' or 'rejected' label for each completion, which makes such a dataset easier to crowd-source (https://github.com/huggingface/trl/pull/1181/).
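The difference in data format can be sketched as follows: a paired, DPO-style row holds a chosen and a rejected completion for the same prompt, whereas a KTO-style row holds a single completion with a binary label. The field names (prompt, chosen, rejected, completion, label) are assumed to match the conventions of TRL's trainers and should be checked against the trainer documentation; the helper paired_to_unpaired is hypothetical.

```python
def paired_to_unpaired(rows):
    """Hypothetical helper: turn paired preference rows into KTO-style labelled rows.

    Each input row is assumed to look like {"prompt", "chosen", "rejected"};
    each output row looks like {"prompt", "completion", "label"}, where
    label=True marks a desirable completion and label=False an undesirable one.
    """
    unpaired = []
    for row in rows:
        unpaired.append({"prompt": row["prompt"], "completion": row["chosen"], "label": True})
        unpaired.append({"prompt": row["prompt"], "completion": row["rejected"], "label": False})
    return unpaired


paired = [{"prompt": "What is 2 + 2?", "chosen": "4", "rejected": "5"}]
kto_rows = paired_to_unpaired(paired)

# Crowd-sourced feedback fits directly: one completion and one thumbs-up/down, no partner needed.
kto_rows.append({"prompt": "Name a prime number.", "completion": "9", "label": False})
```

Any paired dataset can be converted this way, but the reverse does not hold: single thumbs-up/down ratings cannot be turned into pairs without a matching completion for the same prompt, which is why the unpaired format is easier to collect.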

github-actions[bot] commented 7 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.