dechunwang opened this issue 1 week ago
Hi dechun, sorry for the late reply.
Your suggestion is quite reasonable. I will raise a PR later. You can also make the change if you would like to contribute to this project.
I'd like to add a few more comments about your scenario. Some tips:
Set `ppo.gen.force_no_logits_mask=True`; otherwise the logits mask occupies a huge amount of GPU memory while yielding only a minor improvement in learning performance.
Use an identical parallel strategy for all models if your resource budget is tight (e.g., `allocation_mode="d1m8p1"`; I believe TP is more memory-efficient than PP, though). This disables parameter reallocation but still enables offloading.
For now, calculate `n_mbs` based on `per_device_train_batch_size=1`. For PPO training, the global batch size is `per_device_train_batch_size * (pp_size * 2) * dp_size * ppo_n_mbs * n_mbs`. For generation or inference, it is `per_device_train_batch_size * pp_size * dp_size * n_mbs`. Though inconvenient, I think implicitly setting `per_device_train_batch_size=1` will fix the OOM issue.
If OOM still happens, it would be helpful if you could share the log and CLI configuration. Discussion is welcome.
Hello there,

First, I'd like to express my appreciation for your excellent work on this project. While experimenting with PPO/RW training using this repository, I consistently encounter out-of-memory (OOM) errors with the following configuration:
This error is unexpected given the relatively small model size and parallelism configuration.
The project currently offers an `n_mbs` setting, which splits data batches into `n_mbs` chunks. However, this approach has a limitation: setting `n_mbs` correctly without knowing the global batch size is very hard.

After reviewing the documentation, I believe a crucial feature is missing: `per_device_train_batch_size` or `mini_batch_size`.
A direct `mini_batch_size` setting would provide more intuitive and precise control over batch sizes across different parallelism configurations. It would allow users to specify the mini-batch size for each data-parallel (DP) rank, providing several benefits (e.g., bounding per-device memory usage by setting `mini_batch_size = 1`).

Would it be possible to consider adding this feature in a future update?
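To illustrate the request, here is a hypothetical sketch of how a user-facing `mini_batch_size` could be translated into the existing `n_mbs` mechanism. The helper name and the assumption that samples split evenly across DP and PP ranks (as in the generation formula discussed in this thread) are mine, not the project's:

```python
import math

def n_mbs_from_mini_batch_size(global_batch_size, dp_size, pp_size,
                               mini_batch_size):
    """Hypothetical helper: derive n_mbs so each micro-batch on a DP rank
    holds at most `mini_batch_size` samples.

    Assumes the generation-style split discussed above, i.e.
    global_batch_size = per_device_batch * pp_size * dp_size * n_mbs.
    """
    per_rank = global_batch_size // (dp_size * pp_size)
    # Round up so every micro-batch stays within the requested size.
    return math.ceil(per_rank / mini_batch_size)

# A global batch of 128 on dp_size=4, pp_size=2 leaves 16 samples per rank;
# with mini_batch_size=1 this requires n_mbs=16.
print(n_mbs_from_mini_batch_size(128, dp_size=4, pp_size=2, mini_batch_size=1))  # 16
```

With such a helper, users would state the quantity they actually care about (memory per device) and let the framework compute `n_mbs`, instead of doing this arithmetic by hand.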