Hi @Santosh-Gupta
Thanks for raising this issue. Indeed, the fix you suggested (adding num_labels=1) should be the right one; see a related discussion here: https://github.com/huggingface/trl/issues/558. Would you be happy to open a PR to fix the example snippet in the README?
Thanks @younesbelkada, yup, I would love to do the PR. Will do shortly.
@younesbelkada here it is
Awesome, thanks for the quick fix!
This might be a dumb question, but I am having trouble following how the README example matches the way reward modeling is described in the recent papers I have read on it.
In the README, the example uses AutoModelForSequenceClassification as the model for the RewardTrainer:
https://github.com/huggingface/trl#rewardtrainer
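From memory, that example looks roughly like this (an approximate reconstruction, not a verbatim copy of the README; the model name and dataset are placeholders):

```python
# Approximate reconstruction of the README's RewardTrainer example (not verbatim);
# "gpt2" and `dataset` are placeholders.
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer

model = AutoModelForSequenceClassification.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# `dataset` is assumed to already contain the chosen/rejected pairs in the format
# RewardTrainer expects (input_ids_chosen, attention_mask_chosen,
# input_ids_rejected, attention_mask_rejected).
trainer = RewardTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
)
trainer.train()
```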
The output of AutoModelForSequenceClassification is a tuple: the first item is a logits tensor of shape (batch_size, 2) by default (one logit for the positive class, one for the negative class), and the second item is the hidden states. For convenience, here's a Colab notebook demo with the code set up: https://colab.research.google.com/drive/1VNx3QKhdSGbUIUrkzLi5_VGaLc_wrI-2?usp=sharing
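A quick way to see the default shape (a minimal sketch; "gpt2" is just an example checkpoint):

```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

model = AutoModelForSequenceClassification.from_pretrained("gpt2")  # default num_labels=2
model.config.pad_token_id = tokenizer.pad_token_id  # needed for padded batches

inputs = tokenizer(["response A", "a longer response B"], padding=True, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

print(logits.shape)  # torch.Size([2, 2]) -> (batch_size, num_labels=2)
```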
The RewardTrainer code takes the first item from this tuple, which would therefore be a tensor of shape (batch_size, 2):
https://github.com/huggingface/trl/blob/84393f3b94b452adb71d28f42d52c10f0023a834/trl/trainer/reward_trainer.py#L167C1-L174C13
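For reference, those lines compute something along these lines (a paraphrased sketch from memory, not a verbatim copy of the linked source):

```python
import torch

def compute_pairwise_loss(model, inputs):
    """Paraphrase (from memory) of what RewardTrainer.compute_loss does."""
    # first element of the model output tuple = the classification logits
    rewards_chosen = model(
        input_ids=inputs["input_ids_chosen"],
        attention_mask=inputs["attention_mask_chosen"],
    )[0]
    rewards_rejected = model(
        input_ids=inputs["input_ids_rejected"],
        attention_mask=inputs["attention_mask_rejected"],
    )[0]
    # pairwise ranking loss: pushes the chosen reward above the rejected one
    return -torch.nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
```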
However, based on the InstructGPT and Llama papers, I believe the reward should be a single scalar per sequence.
This is stated just below Equation 1 on page 8 of https://arxiv.org/pdf/2203.02155.pdf, where r_θ(x, y) is defined as the scalar output of the reward model.
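For context, the pairwise reward-model loss there (reproduced from memory, so treat as approximate) is:

```latex
\mathrm{loss}(\theta) =
  -\frac{1}{\binom{K}{2}}
  \, \mathbb{E}_{(x,\, y_w,\, y_l) \sim D}
  \Big[ \log \big( \sigma \big( r_\theta(x, y_w) - r_\theta(x, y_l) \big) \big) \Big]
```

where r_θ(x, y) is the scalar reward for prompt x and completion y, and y_w is the completion preferred over y_l.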
So, for the README example to match the reward modeling approach of the InstructGPT/Llama papers, I believe AutoModelForSequenceClassification should be instantiated with the number of labels set to 1, along the lines of the snippet below.
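A minimal sketch of the suggested fix (the model name is just a placeholder):

```python
from transformers import AutoModelForSequenceClassification

# one output logit per sequence -> a scalar reward, matching the papers
model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)
```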
Then rewards_chosen and rewards_rejected will have shape (batch_size, 1).
Or perhaps I have a blind spot in thinking this through?
Though, thinking it through theoretically, even if it doesn't match the implementations in those papers, maybe it wouldn't negatively impact training: it would just be optimizing a pair of logits instead of one, which could even act as a form of regularization.