pytorch / torchtune

PyTorch native finetuning library
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

Add NLL and DPOP weighting to DPO losses #2032

Open · RdoubleA opened this issue 2 days ago


Several modifications to the DPO loss function have been shown to improve DPO model quality. These include adding a weighted negative log-likelihood (NLL) loss term (https://arxiv.org/pdf/2404.19733) and the DPO-Positive (DPOP) loss (https://arxiv.org/pdf/2402.13228).
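For reference, the two modifications can be summarized roughly as below. This is a hedged summary, not a transcription of either paper: $\beta$ is the usual DPO temperature, $\lambda$ is the DPOP coefficient, $\alpha$ is an NLL weight, $y_w$ / $y_l$ are the chosen / rejected responses, and $|y_w|$ is the chosen response length in tokens. The length normalization and where the weight sits (on the NLL term vs. the DPO term) are assumptions here; the NLL paper may place the coefficient differently, which only changes the overall scale.

```math
\begin{aligned}
\Delta_\theta &= \log\frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \log\frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)} \\
\mathcal{L}_{\mathrm{DPO}} &= -\log \sigma\left(\beta\,\Delta_\theta\right) \\
\mathcal{L}_{\mathrm{DPOP}} &= -\log \sigma\left(\beta\left[\Delta_\theta - \lambda \max\left(0,\ \log\frac{\pi_{\mathrm{ref}}(y_w \mid x)}{\pi_\theta(y_w \mid x)}\right)\right]\right) \\
\mathcal{L}_{\mathrm{DPO+NLL}} &= \mathcal{L}_{\mathrm{DPO}} - \alpha\,\frac{\log \pi_\theta(y_w \mid x)}{|y_w|}
\end{aligned}
```

The DPOP penalty is nonzero only when the policy assigns the chosen response a lower likelihood than the reference model does, which is the failure mode it targets.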

We could implement these as additional float parameters on the DPO loss functions that add a weighted loss term, or as separate loss functions that are combined with the DPO loss. Before choosing, we should first assess how much impact these papers have had.
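To make the first option concrete, here is a minimal scalar sketch, assuming each input is a per-sequence summed log-probability. The function `dpo_loss_with_extras` and its parameters `nll_weight`, `dpop_lambda`, and `chosen_num_tokens` are hypothetical and not existing torchtune API; a real implementation would operate on batched tensors inside the existing DPO loss modules. With both new weights at `0.0` it reduces to the standard sigmoid DPO loss.

```python
import math


def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def dpo_loss_with_extras(
    policy_chosen_logp: float,
    policy_rejected_logp: float,
    ref_chosen_logp: float,
    ref_rejected_logp: float,
    chosen_num_tokens: int,
    beta: float = 0.1,
    nll_weight: float = 0.0,   # hypothetical: weight on the NLL term
    dpop_lambda: float = 0.0,  # hypothetical: DPOP penalty coefficient
) -> float:
    """Hypothetical DPO loss with optional NLL and DPOP terms (sketch only)."""
    if chosen_num_tokens <= 0:
        raise ValueError("chosen_num_tokens must be positive")

    # Standard DPO log-ratio margin between chosen and rejected responses.
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp

    # DPOP penalty: positive only when the policy lowers the chosen
    # response's log-prob below the reference model's.
    dpop_penalty = max(0.0, ref_chosen_logp - policy_chosen_logp)

    logits = chosen_logratio - rejected_logratio - dpop_lambda * dpop_penalty
    loss = -log_sigmoid(beta * logits)

    # Length-normalized NLL on the chosen response (assumed normalization).
    nll = -policy_chosen_logp / chosen_num_tokens
    return loss + nll_weight * nll


# Example: policy and reference agree, so the DPO term is log(2).
print(dpo_loss_with_extras(-10.0, -12.0, -10.0, -12.0, chosen_num_tokens=5))
```

One design consideration: default-zero float parameters keep the existing behavior unchanged and need no recipe changes, whereas separate loss functions would require the recipe to sum the terms itself but keep each loss independently testable.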

cc @SalmanMohammadi