huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

Rewards actual_end index in PPOv2 trainer #1893

Open zhuyuzy opened 2 months ago

zhuyuzy commented 2 months ago

In PPOv2 trainer's train(), under # 4. compute rewards, sequence_lengths_p1 is used when computing the reward index:

actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)

This means the reward is assigned one step after the stop_token_id. As far as I know, the reward is usually assigned to the final token, i.e., the exact position of the stop_token_id. Did I miss or misunderstand something?
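For concreteness, here is a scalar pure-Python sketch of what that torch.where line selects; the helper name and the toy values are mine, not from the trainer:

```python
# Scalar sketch of the torch.where line above (hypothetical helper).
# actual_end picks sequence_lengths + 1, except when that index would fall
# past the end of the rewards tensor, where it falls back to sequence_lengths.
def actual_end(sequence_lengths: int, rewards_len: int) -> int:
    sequence_lengths_p1 = sequence_lengths + 1
    if sequence_lengths_p1 < rewards_len:
        return sequence_lengths_p1
    return sequence_lengths

print(actual_end(3, 6))  # stop token mid-sequence: one past sequence_lengths
print(actual_end(5, 6))  # sequence_lengths_p1 would run off the end: clamped
```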

qgallouedec commented 2 months ago

Thanks for raising this question.

Actually, the name sequence_lengths is misleading, since it's actually the sequence length - 1:

https://github.com/huggingface/trl/blob/cbcaa46cd3c02c0e7f724b764c5848ae73796de7/trl/trainer/ppov2_trainer.py#L351

Consequently, if stop_token_id == 50277 and postprocessed_response == [[187, 187, 3, 50277, 13, 642]], sequence_lengths (and actual_end) will get the value 2 (and not 3).

zhuyuzy commented 2 months ago

> Thanks for raising this question.
>
> Actually, the name sequence_lengths is misleading, since it's actually the sequence length - 1:
>
> https://github.com/huggingface/trl/blob/cbcaa46cd3c02c0e7f724b764c5848ae73796de7/trl/trainer/ppov2_trainer.py#L351
>
> Consequently, if stop_token_id == 50277 and postprocessed_response == [[187, 187, 3, 50277, 13, 642]], sequence_lengths (and actual_end) will get the value 2 (and not 3).

Thank you for your reply.

However, that's not what I observe in my experiment. first_true_indices(postprocessed_response == tokenizer.pad_token_id) returns the first index where pad_token_id appears, which is 4 in your example. Subtracting 1 gives sequence_lengths = 3 (not 2). And finally, actual_end = sequence_lengths_p1 when a stop_token_id exists, i.e., sequence_lengths + 1 = 4 in your example.

zhuyuzy commented 1 month ago

> Thanks for raising this question. Actually, the name sequence_lengths is misleading, since it's actually the sequence length - 1: https://github.com/huggingface/trl/blob/cbcaa46cd3c02c0e7f724b764c5848ae73796de7/trl/trainer/ppov2_trainer.py#L351
>
> Consequently, if stop_token_id == 50277 and postprocessed_response == [[187, 187, 3, 50277, 13, 642]], sequence_lengths (and actual_end) will get the value 2 (and not 3).
>
> Thank you for your reply.
>
> However, it's not the same in my experiment. first_true_indices(postprocessed_response == tokenizer.pad_token_id) returns the first index where the pad_token_id is, in your case 4, then minus 1 we have the sequence_length, in your case 3 (not 2). And finally we have actual_end = sequence_lengths_p1 if stop_token_id exists, which is sequence_length plus 1 (in your case 4).

@qgallouedec Hello, could you help clarify this? BTW, I've run the experiment with the setting

actual_end = sequence_lengths

and found that the result is almost identical to that with

actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)

However, I'm still puzzled as to why you chose this approach.

Really appreciate it if you could provide an explanation. Thank you!

qgallouedec commented 2 weeks ago

> Consequently, if stop_token_id == 50277 and postprocessed_response == [[187, 187, 3, 50277, 13, 642]], sequence_lengths (and actual_end) will get the value 2 (and not 3).
>
> However, it's not the same in my experiment. first_true_indices(postprocessed_response == tokenizer.pad_token_id) returns the first index where the pad_token_id is, in your case 4

why 4?

zhuyuzy commented 2 weeks ago

> > Consequently, if stop_token_id == 50277 and postprocessed_response == [[187, 187, 3, 50277, 13, 642]], sequence_lengths (and actual_end) will get the value 2 (and not 3).
> >
> > However, it's not the same in my experiment. first_true_indices(postprocessed_response == tokenizer.pad_token_id) returns the first index where the pad_token_id is, in your case 4
>
> why 4?

Thank you for your reply.

sequence_lengths is 3 because first_true_indices returns the first index where pad_token_id appears.

If stop_token_id == 50277 and pad_token_id == 13, then after everything past the stop token is replaced with pad_token_id, postprocessed_response == [[187, 187, 3, 50277, 13, 13]].

first_true_indices then returns the position of the first pad_token_id, which is 4, so sequence_lengths gets the value 4 - 1 = 3.
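This walkthrough can be sketched end to end in plain Python; first_true_indices_1d below is a hypothetical scalar stand-in for TRL's torch-based first_true_indices helper, and the token ids are the ones from the example above:

```python
STOP_TOKEN_ID = 50277
PAD_TOKEN_ID = 13

def first_true_indices_1d(mask, default):
    """Index of the first True in mask, or default if there is none."""
    for i, flag in enumerate(mask):
        if flag:
            return i
    return default

# Response after postprocessing: everything past the stop token is pad.
response = [187, 187, 3, 50277, 13, 13]

first_pad = first_true_indices_1d(
    [tok == PAD_TOKEN_ID for tok in response], len(response)
)                                            # 4, first pad position
sequence_lengths = first_pad - 1             # 3, the index of the stop token
sequence_lengths_p1 = sequence_lengths + 1   # 4

# Scalar version of the trainer's torch.where line.
actual_end = (
    sequence_lengths_p1
    if sequence_lengths_p1 < len(response)
    else sequence_lengths
)                                            # 4, one step past the stop token
print(sequence_lengths, actual_end)
```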