huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

Bug in calling model.eval() in PPO #1569

Closed: idanshen closed this issue 1 day ago

idanshen commented 3 months ago

When performing a PPO step, the code performs the forward pass at line 798 using the function "batched_forward_pass". However, "batched_forward_pass" puts the model in eval mode (line 986):

    model.eval()

I'm pretty sure this is not intended and can lead to bugs.
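
For reference, here is a minimal PyTorch sketch (a toy model, not TRL code) of what the mode switch changes: in train mode, dropout makes repeated forward passes stochastic, while eval mode disables it.

    import torch
    import torch.nn as nn

    # Toy model with dropout, standing in for the policy model (not TRL code).
    model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5), nn.Linear(8, 2))
    x = torch.randn(1, 8)

    model.train()
    print(torch.allclose(model(x), model(x)))  # usually False: dropout masks differ per call

    model.eval()
    print(torch.allclose(model(x), model(x)))  # True: dropout is off, outputs are deterministic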

github-actions[bot] commented 2 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

B-Gendron commented 1 month ago

I was suspicious about this at first sight, but in the end I think it is intentional. Here is how I would explain it:

If you look at the step() method of the PPOTrainer class, you can see there are two main parts to this training step. The first performs a forward pass on the language model and the reference language model; no backprop happens at this stage, since we just want the outputs in order to apply the policy. In the second phase, starting at line 807, the train_minibatch() method is called, and there the model is in train mode, which is logical because the loss is computed and backpropagated there.
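
To make that flow concrete, here is a self-contained toy sketch of the two phases; the models, data, and loss are placeholders for illustration, not the actual PPOTrainer code.

    import torch
    import torch.nn as nn

    # Hypothetical stand-ins for the policy and reference models.
    policy = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.1), nn.Linear(8, 4))
    ref_model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.1), nn.Linear(8, 4))
    optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
    batch = torch.randn(16, 8)

    # Phase 1: mirrors batched_forward_pass(). Eval mode, no gradients;
    # we only need log probs to score the sampled responses.
    policy.eval()
    ref_model.eval()
    with torch.no_grad():
        old_logprobs = policy(batch).log_softmax(dim=-1)
        ref_logprobs = ref_model(batch).log_softmax(dim=-1)

    # Phase 2: mirrors train_minibatch(). Back in train mode, compute a loss
    # against the stored log probs and backpropagate.
    policy.train()
    new_logprobs = policy(batch).log_softmax(dim=-1)
    loss = (new_logprobs - old_logprobs).pow(2).mean()  # placeholder loss, not the PPO objective
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()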

Hope this helps!

vwxyzjn commented 1 month ago

I think the .eval() is mainly there to ensure the generation log probs and the forward-pass log probs are the same. In theory we could use .train() and just disable the dropout.
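
As a sketch of that alternative (the helper below is illustrative; I believe TRL ships a similar utility, disable_dropout_in_model, in trl.trainer.utils): zeroing every dropout probability makes train mode deterministic, so the two sets of log probs would still match.

    import torch
    import torch.nn as nn

    def disable_dropout(model: nn.Module) -> None:
        # Zero every dropout probability so train mode becomes deterministic.
        for module in model.modules():
            if isinstance(module, nn.Dropout):
                module.p = 0.0

    model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.5), nn.Linear(8, 2))
    model.train()
    disable_dropout(model)

    x = torch.randn(1, 8)
    # With dropout zeroed, two train-mode passes agree, so log probs computed
    # at generation time and at optimization time would match.
    assert torch.allclose(model(x), model(x))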

github-actions[bot] commented 1 week ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.