huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

Question about the logprobs of the policy-generated sentences in PPO trainer #2358

Open yanghh2000 opened 1 week ago

yanghh2000 commented 1 week ago

System Info

trl==0.12.0

Reproduction

In the PPO trainer implementation, the logits of the policy-generated sequences are taken directly from the output of the model's generate() call, i.e., output.scores. However, these returned scores have already been processed or filtered by the transformers LogitsProcessor pipeline (e.g., temperature, top-k, top-p), so they are not the original logits that model.forward() would produce.

with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
    # `logitss` holds the per-step scores returned by generate(), i.e. the
    # logits after the generation-config logits processors have been applied.
    query_responses, logitss = batch_generation(
        unwrapped_model.policy,
        queries,
        args.local_rollout_forward_batch_size,
        processing_class.pad_token_id,
        generation_config,
    )

for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
    query = queries[i : i + args.local_rollout_forward_batch_size]
    query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
    response = query_response[:, context_length:]
    logits = logitss[i : i + args.local_rollout_forward_batch_size]
    # The policy logprobs are computed from these processed scores rather
    # than from a fresh forward pass over the generated sequence.
    all_logprob = F.log_softmax(logits, dim=-1)
    logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
    del logits, all_logprob
    torch.cuda.empty_cache()
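
For reference, here is a minimal standalone sketch (not part of TRL; the model name and generation settings are arbitrary) showing that the scores returned by generate() are post-processor logits: with top-k/top-p enabled, filtered vocabulary entries are set to -inf and the rest are rescaled by temperature, whereas a plain forward pass returns finite logits over the whole vocabulary.

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model_name = "gpt2"  # any small causal LM works for this illustration
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("The quick brown fox", return_tensors="pt")
gen_config = GenerationConfig(
    max_new_tokens=5,
    do_sample=True,
    temperature=0.7,
    top_k=50,
    top_p=0.9,
    output_scores=True,
    return_dict_in_generate=True,
    pad_token_id=tokenizer.eos_token_id,
)

with torch.no_grad():
    out = model.generate(**inputs, generation_config=gen_config)

# Scores from generate(): already passed through the logits processors,
# so tokens filtered by top-k/top-p are set to -inf.
first_step_scores = out.scores[0][0]

# Raw logits for the same position, recomputed with a plain forward pass.
with torch.no_grad():
    raw_logits = model(inputs.input_ids).logits[0, -1]

print("-inf entries in generate() scores:", torch.isinf(first_step_scores).sum().item())
print("-inf entries in forward() logits: ", torch.isinf(raw_logits).sum().item())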

Expected behavior

I think this can affect PPO training: the processed scores assign -inf (i.e., zero probability) to tokens filtered out by top-k/top-p and rescale the rest by temperature, so the stored logprobs do not match the policy's true distribution. A possible solution is to run the policy/ref model forward again over the generated sequence and recompute the logits, as sketched below.
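
A hedged sketch of that fix, reusing the shapes from the snippet above. The helper name recompute_logprobs is hypothetical, and the padding/temperature handling is an assumption about how one might wire it into the trainer, not the library's API.

import torch
import torch.nn.functional as F

def recompute_logprobs(model, query_response, context_length, pad_token_id, temperature=1.0):
    """Forward the full (query + response) sequence and gather per-token logprobs
    of the response tokens under the model's unprocessed output distribution."""
    attention_mask = (query_response != pad_token_id).long()
    output = model(input_ids=query_response, attention_mask=attention_mask)
    # Logits at position t predict token t+1, so shift by one to align with the response.
    logits = output.logits[:, context_length - 1 : -1]
    logits = logits / temperature  # optional: mirror the sampling temperature, drop for raw logits
    response = query_response[:, context_length:]
    all_logprob = F.log_softmax(logits, dim=-1)
    logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
    return logprob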
