[ ] An officially supported task in the examples folder
[ ] My own task or dataset (give details below)
Reproduction
In the PPO trainer implementation, the logits of the policy-generated sequences are taken from the output of the model's generate() call, i.e., output.scores. However, those returned scores have already been processed/filtered by transformers' LogitsProcessor classes (e.g., temperature, top-k, top-p), so they are not the raw logits that model.forward() would produce.
```python
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
    # logitss are the per-step scores returned by generate(), i.e. after the
    # LogitsProcessors (temperature, top-k, top-p) have already been applied
    query_responses, logitss = batch_generation(
        unwrapped_model.policy,
        queries,
        args.local_rollout_forward_batch_size,
        processing_class.pad_token_id,
        generation_config,
    )

for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
    query = queries[i : i + args.local_rollout_forward_batch_size]
    query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
    response = query_response[:, context_length:]
    logits = logitss[i : i + args.local_rollout_forward_batch_size]
    # the policy log-probs are computed from these processed scores,
    # not from raw model.forward() logits
    all_logprob = F.log_softmax(logits, dim=-1)
    logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
    del logits, all_logprob
    torch.cuda.empty_cache()
```
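To illustrate the difference, here is a minimal, self-contained sketch (not part of the trainer code; it assumes the `gpt2` checkpoint is available) showing that `output.scores` from `generate()` contain `-inf` entries introduced by the top-k/top-p warpers, while the raw `forward()` logits do not:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")

# Scores returned by generate() after the logits processors (temperature / top-k / top-p) ran.
gen_out = model.generate(
    **inputs,
    max_new_tokens=1,
    do_sample=True,
    temperature=0.7,
    top_k=50,
    top_p=0.9,
    return_dict_in_generate=True,
    output_scores=True,
)
processed_scores = gen_out.scores[0]  # shape (batch, vocab); filtered tokens are -inf

# Raw logits from a plain forward pass over the same prefix.
raw_logits = model(**inputs).logits[:, -1, :]

print("tokens set to -inf by the processors:", torch.isinf(processed_scores).sum().item())
print("tokens set to -inf in raw logits:", torch.isinf(raw_logits).sum().item())
```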
Expected behavior
I think this can affect PPO training, since log-probabilities computed from the processed scores may misrepresent, or entirely zero out, the probability of some tokens. A possible solution is to run the policy/reference model forward again over the generated sequence and recompute the logits from that forward pass.
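A minimal sketch of that recomputation (reusing the variable names from the snippet above, e.g. `unwrapped_model.policy`, `query_response`, `context_length`, `response`; the temperature division is my assumption about how the sampling temperature would be re-applied, and attention-mask handling is omitted for brevity — this is an illustration, not the actual trainer code):

```python
import torch
import torch.nn.functional as F

with torch.no_grad():
    # plain forward pass: raw logits, no LogitsProcessor (top-k/top-p) applied
    output = unwrapped_model.policy(query_response)
    # keep only the logits at positions that predicted the response tokens
    logits = output.logits[:, context_length - 1 : -1]
    logits /= args.temperature + 1e-7  # assumed: re-apply only the sampling temperature
    all_logprob = F.log_softmax(logits, dim=-1)
    logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
```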
Checklist
[X] I have checked that my issue isn't already filed (see open issues)
[X] I have included my system information
[X] Any code provided is minimal, complete, and reproducible (more on MREs)
[X] Any code provided is properly formatted in code blocks (no screenshots; more on code blocks)
System Info
trl==0.12.0