CarperAI / trlx

A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF)
MIT License
4.51k stars · 471 forks

feat: Add support for DPO #556

Open sandeepchittilla opened 1 year ago

sandeepchittilla commented 1 year ago

Closes #504

This PR adds Direct Preference Optimization (DPO), as introduced in https://arxiv.org/abs/2305.18290.

The loss calculation and concatenated forward pass implementations are adapted from the original TRL library.
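For context, the DPO loss being ported has this shape (a minimal sketch of Eq. 7 from the paper, modeled on TRL's implementation; the function name and signature here are illustrative, not the exact trlx API):

```python
import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps: torch.Tensor,
             policy_rejected_logps: torch.Tensor,
             reference_chosen_logps: torch.Tensor,
             reference_rejected_logps: torch.Tensor,
             beta: float = 0.1):
    """DPO loss over summed per-sequence log-probs of chosen/rejected pairs."""
    # log [pi(y_w|x) / pi(y_l|x)] - log [ref(y_w|x) / ref(y_l|x)]
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps
    logits = pi_logratios - ref_logratios
    # -log sigmoid(beta * logits): pushes the policy toward the chosen response
    losses = -F.logsigmoid(beta * logits)
    # Implicit rewards, detached for logging (reward/* diagnostics)
    chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach()
    rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()
    return losses.mean(), chosen_rewards, rejected_rewards
```

When the policy equals the reference, the logits are zero and the loss is log 2, which is a handy sanity check for the port.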

sandeepchittilla commented 1 year ago

The wandb job: https://wandb.ai/sharma-sandeepch/trlx/runs/f7ym4m9y?workspace=user-sharma-sandeepch

(updated the link to point to a run with a larger batch size)

sandeepchittilla commented 1 year ago

@PhungVanDuy @maxreciprocate I cannot seem to request a review (probably due to permission issues / first contribution reasons). Could you please advise?

PhungVanDuy commented 1 year ago

> @PhungVanDuy @maxreciprocate I cannot seem to request a review (probably due to permission issues / first contribution reasons). Could you please advise?

Thank you so much for the great PR. We will review it ASAP! Can you share any wandb runs?

sandeepchittilla commented 1 year ago

Thank you so much @PhungVanDuy for reviewing 🙏

Yes, it's the same wandb run I shared above. Here you go: https://wandb.ai/sharma-sandeepch/trlx/runs/f7ym4m9y?workspace=user-sharma-sandeepch

LouisCastricato commented 1 year ago

Any update?

sandeepchittilla commented 1 year ago

> Thank you so much for the great PR. We will review it ASAP! Can you share any wandb runs?

@PhungVanDuy were you able to review this?

PhungVanDuy commented 1 year ago

> @PhungVanDuy were you able to review this?

I saw your wandb run, but the charts look quite messy: reward/accuracies and reward/margins do not clearly increase. I suspect this is because you used gpt2 rather than an SFT model on HH for DPO. Can you use this SFT model and this preference dataset to train with this branch?
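For readers following along, the reward/accuracies and reward/margins curves discussed here are typically derived from the implicit DPO rewards roughly as follows (a sketch in the style of TRL's logging; not the exact trlx code):

```python
import torch

def reward_metrics(chosen_rewards: torch.Tensor, rejected_rewards: torch.Tensor):
    """Diagnostics over a batch of implicit DPO rewards."""
    # reward/accuracies: fraction of pairs where the chosen response's
    # implicit reward beats the rejected one's.
    accuracies = (chosen_rewards > rejected_rewards).float().mean().item()
    # reward/margins: mean gap between chosen and rejected implicit rewards.
    # Both curves should trend upward if DPO is learning the preferences.
    margins = (chosen_rewards - rejected_rewards).mean().item()
    return accuracies, margins
```

Flat or noisy versions of these curves are the "messy charts" being described: the policy is not separating chosen from rejected responses.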

sandeepchittilla commented 1 year ago

> I saw your wandb run, but the charts look quite messy: reward/accuracies and reward/margins do not clearly increase. I suspect this is because you used gpt2 rather than an SFT model on HH for DPO. Can you use this SFT model and this preference dataset to train with this branch?

That's indeed what I did, for a quick iteration and because I was limited on compute. I will run it with mistral-7b on the ultrafeedback dataset and get back ASAP.

sandeepchittilla commented 1 year ago

@PhungVanDuy sorry for the delay, the GPUs aren't always available. Here is a DPO run (ongoing) of 1 epoch with mistral-7b-sft-beta on the ultrafeedback_binarized dataset: https://wandb.ai/sharma-sandeepch/trlx/runs/kfpmeonf?workspace=user-sharma-sandeepch

Note:

  • UltraFeedback is a challenging dataset for DPO because the rejected responses are randomly sampled.
  • I have not done an SFT pass on the data, so we see some fluctuating plots.
  • I have limited memory and the GPUs are not best in class, so I've chosen only a subset of test_prefs for evaluation.

PhungVanDuy commented 1 year ago

> @PhungVanDuy sorry for the delay, the GPUs aren't always available. Here is a DPO run (ongoing) of 1 epoch with mistral-7b-sft-beta on the ultrafeedback_binarized dataset: https://wandb.ai/sharma-sandeepch/trlx/runs/kfpmeonf?workspace=user-sharma-sandeepch
>
> Note:
>
> • UltraFeedback is a challenging dataset for DPO because the rejected responses are randomly sampled.
> • I have not done an SFT pass on the data, so we see some fluctuating plots.
> • I have limited memory and the GPUs are not best in class, so I've chosen only a subset of test_prefs for evaluation.

Thank you for the information. I will use SFT-beta to check this, and let me help you run it on my cluster.

PhungVanDuy commented 1 year ago


@sandeepchittilla can you add me on Discord (handle: duyphung.ai)? It will be easier to discuss there. Thank you so much.

StellaAthena commented 10 months ago

I'm excited about DPO support and I hope it'll be added soon!