liutianlin0121 closed this pull request 1 year ago.
Nice PR! Please keep `normalize_samples`. We should calculate it as `normalize_samples = int(args.local_normalize_samples * len(local_devices) * args.world_size)`.
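For concreteness, a minimal sketch of where this derived value could be wired up — the `Args` dataclass and its defaults here are illustrative, not the repo's actual definition:

```python
import jax
from dataclasses import dataclass, field

@dataclass
class Args:
    # per-device, per-process sample count used for reward normalization
    local_normalize_samples: int = 256
    # number of processes participating in training
    world_size: int = 1
    # derived: total samples across all local devices and processes
    normalize_samples: int = field(init=False)

    def __post_init__(self):
        local_devices = jax.local_devices()
        self.normalize_samples = int(
            self.local_normalize_samples * len(local_devices) * self.world_size
        )
```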
I see! Done.
Per OAI's setting, I think `local_rollout_batch_size = 512`. I also added the calculation.
Added two quick fixes. @liutianlin0121, let me know if they look good to you; if so, feel free to merge.
Hey!
As discussed in our previous PR, we can use `pmap` to accelerate the reward model normalization step. This PR does that. The figure below shows the runtime comparison between the pmapped version (cyan) and the previous version (grey). Both versions use `normalize_before` and `normalize_after`. The normalization step of the pmapped version is faster. Test results are almost identical.
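For reference, here is a minimal sketch of the kind of pmapped normalization described above: each device scores its local shard of samples, per-device statistics are combined with `lax.pmean`, and the resulting mean/std determine a reward gain and bias. The `reward_fn` and the gain/bias parametrization are illustrative assumptions, not code lifted from this PR.

```python
import jax
import jax.numpy as jnp
from functools import partial

def reward_fn(params, queries):
    # Stand-in for the real reward model; here just a linear scoring head.
    return queries @ params["w"]

@partial(jax.pmap, axis_name="devices")
def reward_stats(params, queries):
    # Each device scores its local shard; per-device statistics are then
    # averaged across all devices with lax.pmean.
    rewards = reward_fn(params, queries)  # (local_batch,)
    mean = jax.lax.pmean(rewards.mean(), "devices")
    sq_mean = jax.lax.pmean(jnp.square(rewards).mean(), "devices")
    return mean, jnp.sqrt(sq_mean - jnp.square(mean))

def compute_gain_bias(params, sharded_queries, target_mean=0.0, target_std=1.0):
    # sharded_queries has shape (num_devices, local_batch, ...), as pmap expects.
    mean, std = reward_stats(params, sharded_queries)
    mean, std = mean[0], std[0]  # replicated across devices; take one copy
    gain = target_std / std
    return gain, target_mean - gain * mean

# Toy usage: replicate params and shard a batch across local devices.
devices = jax.local_devices()
params = jax.device_put_replicated({"w": jnp.ones(16)}, devices)
queries = jax.random.normal(jax.random.PRNGKey(0), (len(devices), 4, 16))
gain, bias = compute_gain_bias(params, queries)
```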
Additionally, I noted that in our previous PyTorch and JAX versions, we have an argument `args.normalize_samples` that is not used. We instead used `args.local_normalize_samples` to effectively mean `args.normalize_samples`. In this PR, I replaced all uses of `args.local_normalize_samples` with `args.normalize_samples`. But feel free to suggest edits if `args.local_normalize_samples` is still needed.