delip / PyTorchNLPBook

Code and data accompanying Natural Language Processing with PyTorch published by O'Reilly Media https://amzn.to/3JUgR2L
Apache License 2.0

Bug in update_train_state #20

Open · karoly-hars opened this issue 5 years ago

karoly-hars commented 5 years ago

There is a bug in the implementation of the update_train_state function in chapters 3-5. Specifically, when the loss decreases, train_state['early_stopping_best_val'] is never updated (except in the first epoch). As a result, the early stopping criterion can only be met if the loss rises above its first-epoch value.

# If loss worsened
if loss_t >= train_state['early_stopping_best_val']:
    # Update step
    train_state['early_stopping_step'] += 1
# Loss decreased
else:
    # Save the best model
    if loss_t < train_state['early_stopping_best_val']:
        torch.save(model.state_dict(), train_state['model_filename'])

    # Reset early stopping step
    train_state['early_stopping_step'] = 0
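The effect of the missing update can be reproduced with a small pure-Python simulation of the snippet above, without PyTorch. This is a sketch under stated assumptions: stop_epoch and patience are hypothetical names (patience is the number of consecutive non-improving epochs allowed), and, following the description above, the best value is seeded with the first epoch's loss.

```python
from typing import List, Optional


def stop_epoch(val_losses: List[float], patience: int,
               update_best: bool) -> Optional[int]:
    """Return the epoch at which early stopping fires, or None if it never does.

    Hypothetical helper mirroring the snippet's comparison logic.
    Assumption: the best value is seeded with the first epoch's loss.
    update_best=False reproduces the reported bug; True applies the fix.
    """
    best_val = val_losses[0]
    early_stopping_step = 0
    for epoch, loss_t in enumerate(val_losses[1:], start=1):
        if loss_t >= best_val:
            # No improvement over the best value seen so far
            early_stopping_step += 1
        else:
            if update_best:
                best_val = loss_t  # the line the issue asks to add
            early_stopping_step = 0
        if early_stopping_step >= patience:
            return epoch
    return None


losses = [1.0, 0.6, 0.5, 0.5, 0.52, 0.51, 0.53]  # improves, then plateaus
print(stop_epoch(losses, patience=3, update_best=False))  # None
print(stop_epoch(losses, patience=3, update_best=True))   # 5
```

With the fix, training stops at epoch 5 after three epochs that fail to beat 0.5. Without it, every loss is below the first-epoch value of 1.0, so the step counter keeps resetting and early stopping never fires.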

Please add the line train_state['early_stopping_best_val'] = loss_t inside the else branch, as in the later chapters.
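For illustration, here is a minimal sketch of the early-stopping bookkeeping with that line added. It is not the book's code. To keep it runnable without PyTorch, checkpointing is passed in as a hypothetical save_checkpoint callable standing in for torch.save(model.state_dict(), train_state['model_filename']). It also assumes that args has an early_stopping_criteria attribute and that early_stopping_best_val starts at float('inf'), so the first epoch always counts as an improvement. The inner if from the original snippet is dropped because reaching the else branch already guarantees loss_t < early_stopping_best_val.

```python
from types import SimpleNamespace
from typing import Any, Callable, Dict


def update_train_state(args: Any, train_state: Dict[str, Any],
                       save_checkpoint: Callable[[], None]) -> Dict[str, Any]:
    """Early-stopping update with the proposed fix applied (sketch).

    save_checkpoint is a hypothetical stand-in for
    torch.save(model.state_dict(), train_state['model_filename']).
    """
    loss_t = train_state['val_loss'][-1]

    # Loss did not improve on the best value so far
    if loss_t >= train_state['early_stopping_best_val']:
        train_state['early_stopping_step'] += 1
    # Loss decreased: save the model and remember the new best value
    else:
        save_checkpoint()
        train_state['early_stopping_best_val'] = loss_t  # the added line
        train_state['early_stopping_step'] = 0

    # Assumed threshold attribute on args
    train_state['stop_early'] = (
        train_state['early_stopping_step'] >= args.early_stopping_criteria)
    return train_state


# Usage: the loss improves twice, then fails to improve for two epochs.
args = SimpleNamespace(early_stopping_criteria=2)
state = {'val_loss': [], 'early_stopping_best_val': float('inf'),
         'early_stopping_step': 0, 'stop_early': False}
saves = []
for epoch, loss in enumerate([0.9, 0.7, 0.75, 0.72]):
    state['val_loss'].append(loss)
    update_train_state(args, state, lambda: saves.append(epoch))
    if state['stop_early']:
        break

print(saves)                            # [0, 1]
print(state['early_stopping_best_val'])  # 0.7
print(epoch)                            # 3
```

Because the best value now tracks the lowest loss seen, the checkpoint and the early-stopping counter are compared against the same reference point.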