pytorch / torchtune

A Native-PyTorch Library for LLM Fine-tuning
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

[RFC] RLHF follow-ups #1395

Open SalmanMohammadi opened 3 weeks ago

SalmanMohammadi commented 3 weeks ago

There are several optimizations to our PPO recipe which could help push it closer to SOTA performance. There are also several pieces of documentation we could offer alongside this recipe to increase visibility and improve accessibility. The lists below are non-comprehensive, and not every item is required.

Documentation


Optimizations

Rough benchmarks from DeepSpeed

I think the results from this page all use LoRA. Nonetheless, it's one of the only sources of compute usage figures for a modern RLHF implementation.

[benchmark image]

*It's unclear what size of reward model is used here. Throughout the blog post they use reward model sizes << policy model sizes.

[benchmark image]

They also state:

For now, we suggest that users use "Total-GPU-Memory-in-GB / 6" as the upper parameter bound in billions for the sum of the actor model and critic model, for safety. Nevertheless, users are welcome to try the real limit.

Which gives ~13.3B parameters for the combined actor + critic models on a single A100 80GB.
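As a quick sanity check on that arithmetic, here is a minimal sketch of the rule of thumb (the helper name is hypothetical, not part of DeepSpeed or torchtune):

```python
def max_actor_plus_critic_params_billions(gpu_memory_gb: float) -> float:
    """DeepSpeed-Chat rule of thumb: upper bound, in billions of parameters,
    for the combined actor + critic models on a single GPU."""
    return gpu_memory_gb / 6


# A100 80GB -> 80 / 6 ~= 13.3B combined parameters
print(max_actor_plus_critic_params_billions(80.0))
```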


Compile issues @ebsmothers


cc @kartikayk

gau-nernst commented 2 weeks ago

Not sure if it will be useful for you, but there are 8-bit and 4-bit AdamW optimizers in torchao: https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim. Both support FSDP1/FSDP2. The 8-bit AdamW should match bnb exactly, and the 4-bit version should match lpmm exactly. They are included in the torchao 0.4 release, but there is a bug in handling LR schedules (fixed on the main branch).
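For reference, a minimal sketch of swapping in one of these optimizers, assuming the `AdamW8bit` / `AdamW4bit` classes exposed by `torchao.prototype.low_bit_optim` (check the linked repo for the exact class names in your torchao version):

```python
import torch
from torchao.prototype.low_bit_optim import AdamW8bit  # AdamW4bit for the 4-bit variant

# Toy stand-in for the policy/value model being fine-tuned.
model = torch.nn.Linear(4096, 4096, device="cuda")

# Drop-in replacement for torch.optim.AdamW; optimizer state is stored in 8-bit.
optimizer = AdamW8bit(model.parameters(), lr=1e-5, weight_decay=0.0)

for _ in range(3):
    loss = model(torch.randn(8, 4096, device="cuda")).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```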