llava-rlhf / LLaVA-RLHF

Aligning LMMs with Factually Augmented RLHF
https://llava-rlhf.github.io/
GNU General Public License v3.0

Training on RTX 4090 #18

Closed · luohaowen2003 closed 11 months ago

luohaowen2003 commented 11 months ago

Thank you for your awesome work!

I wonder whether it is possible to train the reward model and initialize the policy model on a node with 8 RTX 4090 GPUs.

I found that the PPO process can be run with train_rl_model.sh under this setting, but when training the RM with train_reward_model.sh or initializing the policy model with initialize_policy_model.sh, I run out of memory. This seems strange to me, since the PPO process is presumably more demanding.

I've tried reducing per_device_train_batch_size to 1, but I still run into the same issue.

Thank you!

Edward-Sun commented 11 months ago

Hi, this is because we accidentally used QLoRA (4-bit base weights) during PPO but 16-bit weights for SFT/RM. You might be able to train the RM or initialize the policy model with 4-bit as well.
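(For context on why 16-bit doesn't fit: a 13B model needs roughly 26 GB for the weights alone, which already exceeds a 4090's 24 GB, whereas 4-bit NF4 weights take about 6.5 GB. If the repo's scripts don't already expose a 4-bit flag, a minimal sketch of QLoRA-style loading via the Hugging Face `transformers`/`bitsandbytes` integration could look like the following; the checkpoint path is a placeholder, not the actual model used in this repo.)

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# QLoRA-style config: 4-bit NF4 base weights with double quantization,
# bf16 compute dtype for activations / LoRA adapters.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "path/to/base-model",  # placeholder: substitute the actual SFT/RM base checkpoint
    quantization_config=bnb_config,
    device_map="auto",
)
```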

luohaowen2003 commented 11 months ago

Thanks!