vwxyzjn / lm-human-preference-details

RLHF implementation details of OAI's 2019 codebase
MIT License

use pmap to normalize reward model #17

Closed liutianlin0121 closed 1 year ago

liutianlin0121 commented 1 year ago

Hey!

As discussed in our previous PR, we can use pmap to accelerate the reward model normalization step. This PR does that.
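For context, here is a minimal sketch of what a pmapped normalization pass can look like, using jax.lax.pmean to aggregate reward statistics across devices. The names, shapes, and data below are illustrative assumptions, not the PR's exact code:

```python
import jax
import jax.numpy as jnp

def per_device_stats(rewards):
    # rewards: this device's shard of reward-model outputs
    mean = jax.lax.pmean(jnp.mean(rewards), axis_name="batch")
    # Average E[r^2] across devices as well, so std is a global statistic.
    sq_mean = jax.lax.pmean(jnp.mean(rewards ** 2), axis_name="batch")
    std = jnp.sqrt(sq_mean - mean ** 2)
    return mean, std

# pmap runs the function on every local device in parallel; axis_name ties
# the pmean collectives above to this mapped device axis.
pmapped_stats = jax.pmap(per_device_stats, axis_name="batch")

# Illustrative data: shape (num_local_devices, per_device_samples).
rewards_sharded = jnp.ones((jax.local_device_count(), 8))
mean, std = pmapped_stats(rewards_sharded)
# The reward head's gain/bias would then be updated from mean/std,
# mirroring the normalize_before / normalize_after steps.
```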

The figure below shows the run-time comparison between the pmapped version (cyan) and the previous version (grey). Both versions use normalize_before and normalize_after. The normalization step of the pmapped version is faster.

[Screenshot: run-time comparison, 2023-09-02]

Test results are almost identical: [Screenshot: test-result comparison, 2023-09-02]

Additionally, I noticed that in our previous PyTorch and JAX versions, the argument args.normalize_samples is unused; we instead used args.local_normalize_samples to effectively mean args.normalize_samples. In this PR, I replaced all occurrences of args.local_normalize_samples with args.normalize_samples. But feel free to suggest edits if args.local_normalize_samples is still needed.

vwxyzjn commented 1 year ago

Nice PR! Please keep the normalize_samples. We should calculate it as normalize_samples = int(args.local_normalize_samples * len(local_devices) * args.world_size)
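For example, with hypothetical values of 8 local devices and a world size of 1, the global count would work out as:

```python
# Hypothetical values, just to illustrate how the global count scales.
local_normalize_samples = 256   # samples scored per device
local_devices = list(range(8))  # stand-in for the host's accelerators
world_size = 1                  # number of participating hosts/processes

normalize_samples = int(local_normalize_samples * len(local_devices) * world_size)
print(normalize_samples)  # 2048
```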

liutianlin0121 commented 1 year ago

> keep the normalize_samples. We should calculate it as normalize_samples = int(args.local_normalize_samples * len(local_devices) * args.world_size)

I see! Done.

vwxyzjn commented 1 year ago

Per OAI's setting, I think local_rollout_batch_size = 512. I also added the calculation.

vwxyzjn commented 1 year ago

Added two quick fixes. @liutianlin0121, let me know if they look good to you. If so, feel free to merge.