Closed idanshen closed 1 day ago
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
I was suspicious about this at first sight too, but I now think it is normal. Here is how I would explain it:
If you look at the `step()` method in the `PPOTrainer` class, you'll see there are two main parts to this training. The first performs a forward pass on the language model and the reference language model; no backprop happens at this stage, since we only want the outputs in order to apply the policy. In the second phase, starting at line 807, the `train_minibatch()` method is called, and there the model is in train mode, which is logical because the loss is computed there.
Hope this helps!
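The two-phase structure described above can be sketched as follows. This is a minimal illustration, not the actual TRL code: the model, optimizer, and loss are stand-ins, and only the eval/no-grad rollout phase followed by a train-mode update phase mirrors the real `step()` / `train_minibatch()` split.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-ins for the language model and the frozen reference model.
policy = nn.Linear(4, 4)
ref_policy = nn.Linear(4, 4)
optimizer = torch.optim.SGD(policy.parameters(), lr=1e-3)
queries = torch.randn(8, 4)

# Phase 1: forward passes only, no gradients -- we just collect the
# outputs (log probs in the real trainer) needed to build the PPO objective.
policy.eval()
with torch.no_grad():
    old_out = policy(queries)
    ref_out = ref_policy(queries)

# Phase 2: train_minibatch-style update -- the model goes back into
# train mode, the loss is computed, and backprop runs here.
policy.train()
new_out = policy(queries)
loss = ((new_out - ref_out) ** 2).mean()  # placeholder for the PPO loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

So the `.eval()` call only governs the rollout phase; the gradient step itself happens with the model in train mode.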
I think the `.eval()` is mainly there to ensure the generation log probs and the forward log probs are the same. In theory we could use `.train()` and just disable the dropout.
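The point about log-prob consistency is easy to demonstrate: in train mode, dropout makes two forward passes over the same input disagree, whereas in eval mode they are reproducible. A toy model stands in for the policy here; the behavior shown is just PyTorch's standard `Dropout` semantics.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy network with dropout, standing in for the language model.
model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5), nn.Linear(8, 4))
x = torch.randn(2, 8)

# Train mode: dropout samples a fresh mask per call, so repeated
# forward passes over the same input give different outputs.
model.train()
out_train_a = model(x)
out_train_b = model(x)

# Eval mode: dropout is a no-op, so the outputs (and hence any log
# probs computed from them) match across calls.
model.eval()
out_eval_a = model(x)
out_eval_b = model(x)
```

This is why computing the rollout log probs with the model in train mode would make them inconsistent with the log probs used at generation time.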
When performing a PPO step, the code performs the forward pass at line 798 using the function `batched_forward_pass`. However, `batched_forward_pass` puts the model in eval mode (line 986):
I'm pretty sure this is not intended and can lead to bugs.