allenai / RL4LMs

A modular RL library to fine-tune language models to human preferences
https://rl4lms.apps.allenai.org/

Reproducing IMDB results #28

Open mnoukhov opened 1 year ago

mnoukhov commented 1 year ago

Hi, I'm currently running the IMDB experiments and trying to reproduce the PPO and NLPO results from the paper. My PPO results are close, but NLPO is quite far from the reported numbers. Do you have any advice for reproducing the NLPO results?

I'm running the default configs (scripts/training/task_configs/imdb_text_continuation/gpt2_{ppo,nlpo}.yml); the final test results, compared against the results from the paper, are below.

| | Sentiment Score | Fluency (Perplexity) |
|---|---|---|
| zero-shot (ppo) | 0.486 | 32.4 |
| ppo | 0.604 | 33.0 |
| zero-shot (nlpo) | 0.497 | 32.7 |
| nlpo | 0.496 | 40.8 |
| paper's zero-shot | 0.489 | 32.2 |
| paper's ppo | 0.605 | 33.5 |
| paper's nlpo | 0.637 | 32.7 |

The PPO results are similar (my run even has slightly lower perplexity), but NLPO is not at all close. Here are the validation curves:

[image: validation curves]

NLPO also seems to improve in sentiment for a while, then suddenly stops and decreases, while perplexity keeps going up the whole time. Comparing the training curves, the approximate KL loss is much larger for NLPO, though that could be reasonable given NLPO's algorithmic changes. Do you see similar curves?

[image: training curves]

Finally, the paper's Appendix Table 4 says training runs for 10 epochs, but Figure 4 just below it (and the wandb logging) shows these experiments running for 50 epochs. Should I be running for 10 or 50 epochs?

Each experiment is being run on 4 A100 GPUs as per #12

rajcscw commented 1 year ago

@mnoukhov Hey, I will get back to you after checking the configs. There could be config errors that made NLPO unstable.

rajcscw commented 1 year ago

@rajammanabrolu Can you double-check the NLPO config?

mnoukhov commented 1 year ago

Hey @rajammanabrolu have you had a chance to double-check the config and your results? Thanks

mnoukhov commented 1 year ago

Thanks to @rajcscw and @rajammanabrolu: a better NLPO config uses a learning rate of 1e-6 (not 1e-5) and 50 update iterations.

This makes NLPO roughly match the performance of PPO (it is still slightly worse). Hope this helps others working with NLPO.

[image: validation curves with the updated NLPO config]
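For anyone else applying this fix, below is a minimal sketch of where those two changes would go in gpt2_nlpo.yml. The key names (`learning_rate` under `alg.args`, `n_iters` under `train_evaluation`) follow the usual RL4LMs config layout but are my assumption here, so check them against the actual file in the repo.

```yaml
# Hypothetical sketch of the fixed NLPO config -- not a verbatim copy of
# scripts/training/task_configs/imdb_text_continuation/gpt2_nlpo.yml.
# Key names are assumed from the usual RL4LMs config layout.
alg:
  id: nlpo
  args:
    learning_rate: 1.0e-6   # was 1e-5; the larger rate destabilized NLPO
    # ...other args (n_steps, batch_size, n_epochs, ...) left at their defaults
train_evaluation:
  n_iters: 50               # the "update iters 50" from the fix above
  # ...eval settings left at their defaults
```

Runs would then be launched as in the repo README, e.g. `python scripts/training/train_text_generation.py --config_path scripts/training/task_configs/imdb_text_continuation/gpt2_nlpo.yml`.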