hlnchen opened 1 week ago
I am hitting the same issue. I used the official example script, dpo_online.py, to train a 75B LLM with a 75B reward model. Even with 60x8 H100 GPUs, the problem still occurs. Any help, please?
Hello @hlnchen, would you mind sharing a reproducible example that uses the unwrap_model_for_generation() method in a simple training loop that simulates your application?
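For reference, a minimal loop of the kind being requested might look like the sketch below. This is an illustration under assumptions, not code from the report: the model name, prompts, and hyperparameters are placeholders, and the import path for unwrap_model_for_generation reflects where the helper is defined in recent TRL releases.

```python
# Hypothetical minimal repro sketch (not the reporter's code): generate with
# unwrap_model_for_generation() inside an ordinary Accelerate training step.
import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl.models.utils import unwrap_model_for_generation

accelerator = Accelerator()
model_name = "meta-llama/Llama-3.1-70B"  # placeholder; any causal LM shows the pattern
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # Llama tokenizers ship without a pad token
tokenizer.padding_side = "left"            # left-pad prompts for generation
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-6)
model, optimizer = accelerator.prepare(model, optimizer)

prompts = ["The capital of France is"] * 4  # dummy on-policy prompts
for step in range(3):
    batch = tokenizer(prompts, return_tensors="pt", padding=True).to(accelerator.device)

    # Generation phase: under DeepSpeed ZeRO-3 this context manager gathers the
    # sharded parameters so .generate() can run -- the reported memory spike.
    with unwrap_model_for_generation(model, accelerator) as unwrapped_model:
        completions = unwrapped_model.generate(**batch, max_new_tokens=32)

    # Training phase: a stand-in language-modeling loss on the generated sequences.
    loss = model(input_ids=completions, labels=completions).loss
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
```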
System Info
torch==2.4.0
transformers==4.43.4
trl==0.9.6
tokenizers==0.19.1
accelerate==0.32.0
peft==0.12.0
datasets==2.20.0
deepspeed==0.15.0
bitsandbytes==0.43.3
sentencepiece==0.2.0
flash-attn==2.6.3
gcc version 11.4.0 (Ubuntu 11.4.0-1ubuntu1~22.04)
Information
Tasks
An officially supported task in the examples folder
Reproduction
Hi TRL team,
I am hitting OOM errors when fine-tuning a Llama-3.1-70B model with my modified RL trainer. The error seems to happen when unwrapping the model for generation (I am using an on-policy algorithm, so each training step generates some sequences).
My machine has 8 H100 80GB GPUs and I use LoRA, but it looks like unwrap_model_for_generation loads the entire model into memory, causing OOM. Any suggestions? (A rough estimate of the gathered weight size is sketched below.)
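A back-of-the-envelope estimate (assuming bf16 weights and a full ZeRO-3 parameter gather on each rank, which is my assumption about the setup) already exceeds a single 80 GB H100:

```python
# Rough estimate of the memory needed to hold the fully gathered weights on one
# rank. Assumes bf16 (2 bytes/param); activations, KV cache, optimizer and LoRA
# states come on top of this.
params = 70e9              # approximate parameter count of Llama-3.1-70B
bytes_per_param = 2        # bf16
full_gather_gib = params * bytes_per_param / 1024**3
print(f"~{full_gather_gib:.0f} GiB for the gathered weights alone")  # ~130 GiB > 80 GiB per H100
```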
Expected behavior
OOM issue resolved.