openpsi-project / ReaLHF

Super-Efficient RLHF Training of LLMs with Parameter Reallocation
Apache License 2.0

Suggestion for Fine-Grained Batch Control e.g `per_device_train_batch_size` or `mini_batch_size` #80

Open dechunwang opened 1 week ago

dechunwang commented 1 week ago

Hello there! First, I'd like to express my appreciation for your excellent work on this project. While experimenting with PPO/RW using this repository, I consistently encounter Out of Memory (OOM) errors with the following configuration:

This error is unexpected given the relatively small model size and parallelism configuration.

The project currently offers an `n_mbs` setting, which splits each data batch into `n_mbs` chunks. However, this approach has limitations:

  1. Difficulty in determining the exact batch size after data packing
  2. Setting `n_mbs` correctly is very hard without knowing the global batch size
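To illustrate the second limitation, here is a minimal sketch (the function name and parameters are hypothetical, not ReaLHF APIs): with a fixed `n_mbs`, the effective per-device batch size floats with the packed global batch size, so the same setting can produce very different memory footprints.

```python
# Hypothetical illustration (not a ReaLHF API): with a fixed n_mbs,
# the per-device batch size depends on how large the packed batch
# happens to be after sequence packing.
def per_device_batch_size(packed_batch_size: int, n_mbs: int, dp_size: int) -> float:
    """Per-DP-rank batch size implied by splitting into n_mbs chunks."""
    return packed_batch_size / (n_mbs * dp_size)

# The same n_mbs yields very different per-device loads as packing varies:
print(per_device_batch_size(512, 4, 8))  # 16.0 samples per device
print(per_device_batch_size(896, 4, 8))  # 28.0 samples per device
```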

After reviewing the documentation, I believe a crucial feature is missing: `per_device_train_batch_size` or `mini_batch_size`.

A direct `mini_batch_size` setting would provide more intuitive and precise control over batch sizes across different parallelism configurations.

This setting would allow users to specify the mini-batch size for each Data Parallel (DP) rank, providing several benefits:

  1. Fine-grained control in MP/PP/TP scenarios (e.g., virtual parallel)
  2. Better resource utilization (e.g., a 7B parameter model shouldn't require PP > 2 if `mini_batch_size = 1`)
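A rough sketch of how the proposed setting could work (all names here are hypothetical, sketched under the assumption that samples are first sharded across DP ranks): the number of micro-batches would be derived from a fixed per-rank `mini_batch_size`, so per-device memory stays bounded no matter how large the packed global batch turns out to be.

```python
import math

# Hypothetical sketch of the proposed control (not an existing ReaLHF
# option): derive the number of micro-batches from a fixed per-DP-rank
# mini_batch_size instead of asking the user to guess n_mbs.
def n_micro_batches(global_batch_size: int, dp_size: int, mini_batch_size: int) -> int:
    per_rank = math.ceil(global_batch_size / dp_size)  # samples held by each DP rank
    return math.ceil(per_rank / mini_batch_size)       # chunks needed to cap memory

# The derived micro-batch count adapts to the packed batch size
# while each device processes at most mini_batch_size samples at a time:
print(n_micro_batches(512, 8, 1))  # 64 micro-batches
print(n_micro_batches(896, 8, 1))  # 112 micro-batches
```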

Would it be possible to consider adding this feature in a future update?

garrett4wade commented 14 hours ago

Hi dechun, sorry for the late reply.

Your suggestion is quite reasonable. I will raise a PR later. You can also make the change if you would like to contribute to this project.

I'd also like to offer a few more tips about your scenario:

If OOM still happens, it would be helpful if you could share the log and CLI configuration. Discussion is welcome.