huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

Questions about the implementation details of batched_forward_pass and respond_to_batch #15

Closed xiaoda99 closed 3 years ago

xiaoda99 commented 3 years ago

Great work! I have some questions about the implementation details.

1) In PPOTrainer.batched_forward_pass, why are the values shifted one step back relative to logprobs and ref_logprobs? This is done in the following code: https://github.com/lvwerra/trl/blob/master/trl/ppo.py#L178-L185

2) A related question: in respond_to_batch, why not compute the logprob and value at each step while generating the next tokens? If respond_to_batch returned logprobs and values, the first forward pass in batched_forward_pass would be unnecessary (https://github.com/lvwerra/trl/blob/master/trl/ppo.py#L180). I noticed that the official implementation by OpenAI (https://github.com/openai/lm-human-preferences) adopts this strategy.

3) respond_to_batch does not use incremental decoding with cached hidden states. I guess it would be slow for long responses. Is incremental decoding possible within the HuggingFace Transformers framework?

lvwerra commented 3 years ago
  1. The model's output at each position is a prediction for the next token, whereas the log_probs are the log probabilities of the current token. Shifting the values by one step simply aligns the two.

  2. The main motivation was to decouple the generation from the training as much as possible. Since the extra forward pass takes only a fraction of the time of the backward pass, the speedup would be minimal. Keeping the two separate also makes the PPOTrainer interface cleaner.

  3. That's possible. It could be that the generate function in transformers handles this, but I had to implement my own simple decoding function, since the model would exploit several aspects of generate. See the comments here about the custom response function. Feel free to make a PR if you can fix the weaknesses and improve the performance.
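The alignment described in point 1 can be sketched for a single sequence in plain Python. This is only an illustration under the usual causal-LM convention (the output at position t scores the token at t+1); the helper names log_softmax and align_response_stats and the toy data are hypothetical and not taken from trl.

```python
import math

def log_softmax(row):
    """Numerically stable log-softmax over one list of logits."""
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]

def align_response_stats(logits, values, tokens, gen_len):
    """Return (logprobs, values) for the last `gen_len` tokens of one sequence.

    logits[t] and values[t] are the model outputs at position t, i.e. after
    the model has read tokens[0..t]. (Hypothetical helper, not trl's API.)
    """
    # logits[t] scores the token at t+1, so pair logits[:-1] with tokens[1:].
    logprobs = [
        log_softmax(logits[t])[tokens[t + 1]]
        for t in range(len(tokens) - 1)
    ]
    # The last gen_len entries are the log-probs of the response tokens.
    response_logprobs = logprobs[-gen_len:]
    # Take the values from the positions that produced those predictions:
    # one step before each response token, hence the shifted slice.
    response_values = values[-gen_len - 1:-1]
    return response_logprobs, response_values

# Toy data: 2 query tokens + 2 response tokens, vocabulary of size 3.
tokens = [0, 1, 2, 1]
logits = [[0.0, 0.0, 0.0] for _ in tokens]  # uniform predictions
values = [10.0, 11.0, 12.0, 13.0]

lp, v = align_response_stats(logits, values, tokens, gen_len=2)
print(lp)  # two log-probs, each equal to -log(3)
print(v)   # [11.0, 12.0]
```

Both returned lists have length gen_len, and entry i of each comes from the same model position: the one just before response token i was produced.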

Cheers, Leandro
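For readers wondering what incremental decoding (question 3) looks like, here is a framework-agnostic toy sketch. The step(new_tokens, cache) interface and the ToyModel class are hypothetical stand-ins, and greedy selection is used only for simplicity instead of the sampling a real response function would use. In Hugging Face Transformers the analogous mechanism is, to my knowledge, passing past_key_values back into the model with use_cache=True.

```python
def greedy_decode(step, prompt, gen_len):
    """Greedy decoding that feeds only new tokens once a cache exists.

    `step(new_tokens, cache)` is a hypothetical model interface: it consumes
    only the tokens not seen yet and returns (next_token_logits, cache).
    """
    tokens = list(prompt)
    # First call: process the whole prompt and build the cache.
    logits, cache = step(tokens, None)
    for i in range(gen_len):
        next_token = max(range(len(logits)), key=logits.__getitem__)
        tokens.append(next_token)
        if i < gen_len - 1:
            # Later calls: pass just the newest token plus the cache.
            logits, cache = step([next_token], cache)
    return tokens[len(prompt):]

class ToyModel:
    """Stand-in for a transformer; the 'cache' is just the tokens seen so far."""

    def __init__(self, vocab_size=5):
        self.vocab_size = vocab_size
        self.tokens_processed = 0  # counts how much input the model has read

    def step(self, new_tokens, cache):
        cache = (cache or []) + list(new_tokens)
        self.tokens_processed += len(new_tokens)
        # Toy rule: always predict (last token + 1) mod vocab_size.
        target = (cache[-1] + 1) % self.vocab_size
        logits = [1.0 if i == target else 0.0 for i in range(self.vocab_size)]
        return logits, cache

model = ToyModel()
response = greedy_decode(model.step, prompt=[0, 1], gen_len=3)
print(response)                # [2, 3, 4]
print(model.tokens_processed)  # 4
```

With the cache, the toy model reads 4 tokens in total (2 for the prompt, then 1 per later step). Re-feeding the full sequence at every step would read 2 + 3 + 4 = 9, and that gap grows quadratically with response length.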