huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

In the reward trainer, `rewards_chosen` and `rewards_rejected` are tensors of shape (batch size, 2) when using AutoModelForSequenceClassification (as shown in the README). Shouldn't they be a single scalar per example? #653

Closed Santosh-Gupta closed 1 year ago

Santosh-Gupta commented 1 year ago

This might be a dumb question, but I am having trouble seeing how the README example matches the way reward modeling is described in the recent papers I've read on it.

From the README, the example shows using AutoModelForSequenceClassification as the model for the RewardTrainer:

https://github.com/huggingface/trl#rewardtrainer

from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer

model = AutoModelForSequenceClassification.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

...
# load trainer
trainer = RewardTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
)

The output of AutoModelForSequenceClassification is a tuple: the first item is the logits, a tensor of shape (batch size, 2) by default (one logit for the positive class, one for the negative class), and the remaining items are model internals such as hidden states. For convenience, here's a Colab notebook with the code set up: https://colab.research.google.com/drive/1VNx3QKhdSGbUIUrkzLi5_VGaLc_wrI-2?usp=sharing
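To make the shape concrete, here is a quick sanity check of my own (not from the notebook; note that GPT-2 needs a pad token set for batched inputs):

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

model = AutoModelForSequenceClassification.from_pretrained("gpt2")
model.config.pad_token_id = tokenizer.pad_token_id

inputs = tokenizer(["a chosen response", "a rejected response"],
                   return_tensors="pt", padding=True)
with torch.no_grad():
    logits = model(**inputs)[0]

print(logits.shape)  # torch.Size([2, 2]) -> two logits per sequence, not one scalar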

The RewardTrainer code takes the first item from this tuple, which would therefore be a tensor of shape (batch size, 2):

        rewards_chosen = model(
            input_ids=inputs["input_ids_chosen"],
            attention_mask=inputs["attention_mask_chosen"],
        )[0]
        rewards_rejected = model(
            input_ids=inputs["input_ids_rejected"],
            attention_mask=inputs["attention_mask_rejected"],
        )[0]

https://github.com/huggingface/trl/blob/84393f3b94b452adb71d28f42d52c10f0023a834/trl/trainer/reward_trainer.py#L167C1-L174C13
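For context (paraphrasing the rest of compute_loss in that file, so the exact line may differ), these two tensors then feed a pairwise loss along the lines of:

import torch.nn as nn

# applied elementwise over whatever shape rewards_chosen / rewards_rejected have
loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()

With two logits per example, the subtraction and log-sigmoid run over both columns rather than over a single scalar reward per sequence.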

However, from the InstructGPT and Llama papers, I believe this should be a scalar.

"where r_θ(x, y) is the scalar output of the reward model for prompt x and completion y with parameters θ, y_w is the preferred completion out of the pair of y_w and y_l, and D is the dataset of human comparisons"

This is from the text just below equation 1 on page 8 of https://arxiv.org/pdf/2203.02155.pdf
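For reference, equation 1 itself (paraphrasing from the paper, so treat the exact form as approximate) is the pairwise ranking loss:

loss(\theta) = -\frac{1}{\binom{K}{2}} \, E_{(x, y_w, y_l) \sim D} \left[ \log \left( \sigma \left( r_\theta(x, y_w) - r_\theta(x, y_l) \right) \right) \right]

i.e. the loss is defined in terms of a single scalar reward r_θ(x, y) per (prompt, completion) pair.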

So for the README example to match the reward modeling approach in the InstructGPT/Llama papers, I believe it should be something like:

        rewards_chosen = model(
            input_ids=inputs["input_ids_chosen"],
            attention_mask=inputs["attention_mask_chosen"],
        )[0][:,0]
        rewards_rejected = model(
            input_ids=inputs["input_ids_rejected"],
            attention_mask=inputs["attention_mask_rejected"],
        )[0][:,0]

Alternatively, I am wondering whether, for the example in the README, AutoModelForSequenceClassification should have the number of labels set to 1:

config2 = AutoConfig.from_pretrained("gpt2", num_labels=1)
model2 = AutoModelForSequenceClassification.from_pretrained("gpt2", config=config2)

So then rewards_chosen and rewards_rejected would each have shape (batch size, 1).
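Continuing from the config2/model2 snippet above (and reusing the tokenizer and pad-token setup from my earlier sanity check), a quick shape check would look like:

import torch

model2.config.pad_token_id = tokenizer.pad_token_id
inputs = tokenizer(["chosen response", "rejected response"],
                   return_tensors="pt", padding=True)
with torch.no_grad():
    rewards = model2(**inputs)[0]

print(rewards.shape)  # torch.Size([2, 1]) -> one scalar reward per sequence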

Or perhaps I have a blind spot in thinking this through?

That said, even if the current setup doesn't match the implementations in those papers, maybe it wouldn't negatively impact training: it would just be optimizing a pair of logits instead of a single one, and perhaps that could even act as a form of regularization.

younesbelkada commented 1 year ago

Hi @Santosh-Gupta, thanks for raising this issue. Indeed, the fix you suggested (adding num_labels=1) should be the right one; see a related discussion here: https://github.com/huggingface/trl/issues/558. Would you be happy to open a PR to fix the example snippet in the README?
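For reference, I imagine the snippet would end up looking roughly like this (exact wording up to your PR, of course):

from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer

model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

...
# load trainer
trainer = RewardTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
)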

Santosh-Gupta commented 1 year ago

Thanks @younesbelkada, yup, I would love to do the PR. Will do shortly.

Santosh-Gupta commented 1 year ago

@younesbelkada here it is

https://github.com/huggingface/trl/pull/657

younesbelkada commented 1 year ago

Awesome, thanks for the quick fix!