huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

Fix `start` index under `batched_forward_pass` #1782

Closed mertsayar8 closed 1 month ago

mertsayar8 commented 3 months ago

The `start` index in `batched_forward_pass` points at the last query token, which does not match the comment on line 1032. This also causes a problem when handling the response tokens, as described in #1781.

This PR sets `start` to the first response token instead of the last query token.

vwxyzjn commented 3 months ago

@mertsayar8 btw we now recommend using the PPOv2Trainer :)

HuggingFaceDocBuilderDev commented 3 months ago

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

github-actions[bot] commented 2 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

qgallouedec commented 1 month ago

Hey, thank you for the PR, and sorry for the late reply. I'm not sure about this. Let's consider the following simple example: the input is [1, 2, 3, 4, 5, 6, 7], where [1, 2, 3] is the query and [4, 5, 6, 7] is the response, and the model predicts [2, 3, 4, 5, 6, 7, 8]. Since position i predicts token i + 1, you want a mask like [0, 0, 1, 1, 1, 1, 0].

With your suggested modification:
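To make the expected mask concrete, here is a minimal pure-Python sketch. The helper name `response_mask` is hypothetical and not part of TRL. It assumes no padding and that position i of the shifted log-probabilities scores input token i + 1, so a response token at input index k is scored at position k - 1.

```python
def response_mask(query, response):
    """Mask over shifted logprob positions: 1 where the predicted token is a response token."""
    tokens = query + response
    n = len(tokens)
    # Position i predicts tokens[i + 1]; the last position predicts a token
    # beyond the input, so it is always masked out.
    return [1 if len(query) <= i + 1 < n else 0 for i in range(n)]


mask = response_mask([1, 2, 3], [4, 5, 6, 7])
print(mask)  # [0, 0, 1, 1, 1, 1, 0]
```

The first unmasked position is index 2, the last query token, because that is the position whose prediction is the first response token.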

import torch

query_batch = [torch.tensor([1, 2, 3])]
response_batch = [torch.tensor([4, 5, 6, 7])]
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1]])

# code from PPOTrainer
masks = torch.zeros_like(attention_mask)
masks[:, :-1] = attention_mask[:, 1:]

for j in range(len(query_batch)):
    # ...
    start = len(query_batch[j])  # logprobs starts from the second query token
    if attention_mask[j, 0] == 0:  # offset left padding
        start += attention_mask[j, :].nonzero()[0]
    end = start + len(response_batch[j])

    masks[j, :start] = 0
    masks[j, end:] = 0

print(masks)  # tensor([[0, 0, 0, 1, 1, 1, 0]])

whereas we would expect the mask to be tensor([[0, 0, 1, 1, 1, 1, 0]]).
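For comparison, the sketch below runs the same masking logic with both choices of `start`. It uses plain Python lists instead of tensors and assumes no left padding. `build_mask` and its `start_offset` parameter are hypothetical, introduced only for illustration. `start_offset=-1` stands for starting at the last query token and `start_offset=0` for starting at the first response token, as this PR proposes.

```python
def build_mask(query, response, start_offset):
    """Reproduce the masking steps on plain lists for a single unpadded sequence."""
    attention_mask = [1] * (len(query) + len(response))
    # Shift left by one: position i scores token i + 1, the last position scores nothing.
    mask = attention_mask[1:] + [0]
    start = len(query) + start_offset
    end = start + len(response)
    # Zero everything outside [start, end).
    return [m if start <= i < end else 0 for i, m in enumerate(mask)]


query, response = [1, 2, 3], [4, 5, 6, 7]
last_query = build_mask(query, response, start_offset=-1)
first_response = build_mask(query, response, start_offset=0)
print(last_query)      # [0, 0, 1, 1, 1, 1, 0]
print(first_response)  # [0, 0, 0, 1, 1, 1, 0]
```

Under these assumptions, starting at the first response token drops the position that scores token 4, so only three of the four response tokens stay unmasked.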

Am I missing something?