Not sure if it will be useful for you, but there are 8-bit and 4-bit AdamW implementations in torchao: https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim. Both support FSDP1/FSDP2. The 8-bit AdamW should match bnb, and the 4-bit version should match lpmm exactly. They are included in the torchao 0.4 release, but there is a bug in handling the LR schedule (fixed in the main branch).
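For anyone trying this in the recipe, a minimal sketch of swapping in the 8-bit optimizer (the import path is the torchao prototype namespace linked above; the model here is a stand-in, not the recipe's policy):

```python
import torch
from torchao.prototype.low_bit_optim import AdamW8bit  # AdamW4bit also available

# Stand-in module; in the recipe this would be the policy model.
model = torch.nn.Linear(4096, 4096, device="cuda")

# Drop-in replacement for torch.optim.AdamW; optimizer state is held in 8 bits.
optimizer = AdamW8bit(model.parameters(), lr=3e-4)

loss = model(torch.randn(8, 4096, device="cuda")).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```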
There are several optimizations to our PPO recipe which could help push its performance closer to SOTA. There are also several pieces of documentation we could offer alongside the recipe to increase visibility and improve accessibility. This list is non-comprehensive, and not every item is required.
## Documentation

## Optimizations
### Rough benchmarks from DeepSpeed
I think the results from this page all use LoRA. Nonetheless, it's one of the only sources of compute usage figures for a modern RLHF implementation.
*It's unclear what size reward model is used here. Throughout the blog post they use reward model sizes << policy model sizes.
They also state figures which give 13B for the combined memory of the actor + critic models on a single A100 80GB.
### Compile issues (cc @ebsmothers)
- [x] #1402
- [ ] Enable compile for the trajectory generation step?
- [ ] Enable compile for the loss step? (see the sketch after this list)
- [ ] How else can we make inference go faster?
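As a rough illustration of the loss-step item, the loss computation is a natural `torch.compile` target since its shapes are static per batch; the function below is a generic clipped PPO surrogate, not the recipe's actual loss module:

```python
import torch

@torch.compile  # static shapes per batch make the loss step an easy compile win
def ppo_loss(logprobs, old_logprobs, advantages, eps: float = 0.2):
    # Standard clipped PPO surrogate (stand-in for the recipe's loss).
    ratio = torch.exp(logprobs - old_logprobs)
    clipped = torch.clamp(ratio, 1.0 - eps, 1.0 + eps)
    return -torch.minimum(ratio * advantages, clipped * advantages).mean()
```

The generation step is harder: decoding produces dynamic sequence lengths, so it likely needs static-shape KV caching (or `dynamic=True`) before compile pays off.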
- [ ] Reference and/or reward model offload to CPU @ebsmothers (rough sketch below)
- [ ] Optimizer offload to CPU (#1351) (benchmark once it lands)
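A minimal sketch of the reference/reward offload idea: keep the frozen model on CPU and page it onto the GPU only for the forward pass that scores trajectories. The helper name and structure here are hypothetical, and each step pays the PCIe transfer cost:

```python
import torch

def forward_on_gpu(frozen_model: torch.nn.Module, tokens: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: move a frozen reference/reward model to the GPU only
    # for its forward pass, then back to CPU to free memory for the policy.
    frozen_model.to("cuda")
    with torch.no_grad():
        out = frozen_model(tokens.to("cuda"))
    frozen_model.to("cpu")
    torch.cuda.empty_cache()
    return out
```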
From the DeepSpeed link above (granted, these aren't strictly performance optimizations):
- [ ] Add LoRA PPO
- [ ] MARL (https://huggingface.co/docs/trl/multi_adapter_rl): single base model, multi-adapter PPO training (rough sketch below)
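The TRL page above keeps one base model in memory and switches LoRA adapters between the policy and reward roles, with the adapter-disabled base acting as the frozen reference. A rough peft-based sketch of that switching, with hypothetical model/config choices, not TRL's actual training loop:

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # hypothetical base

policy_cfg = LoraConfig(r=8, target_modules=["q_proj", "v_proj"])
reward_cfg = LoraConfig(r=8, target_modules=["q_proj", "v_proj"])

model = get_peft_model(base, policy_cfg, adapter_name="policy")
model.add_adapter("reward", reward_cfg)

model.set_adapter("policy")      # trainable adapter for PPO updates / generation
# ... generate trajectories, take policy gradient steps ...

with model.disable_adapter():    # base weights only, i.e. a free frozen reference model
    ...                          # compute reference logprobs for the KL penalty

model.set_adapter("reward")      # same base weights, reward adapter active
# ... score trajectories ...
```

A real setup would also need a score head for the reward adapter; TRL handles this via its value-head model wrapper.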
cc @kartikayk