There is a problem with the implementation of the `update_train_state` function in chapters 3-5. Specifically, when the loss decreases, `train_state['early_stopping_best_val']` is not updated (except in the first epoch), so the early stopping criterion can only be met if the loss rises above its first-epoch value.
```python
# If loss worsened
if loss_t >= train_state['early_stopping_best_val']:
    # Update step
    train_state['early_stopping_step'] += 1
# Loss decreased
else:
    # Save the best model
    if loss_t < train_state['early_stopping_best_val']:
        torch.save(model.state_dict(), train_state['model_filename'])
    # Reset early stopping step
    train_state['early_stopping_step'] = 0
```
Please add the line `train_state['early_stopping_best_val'] = loss_t` in the `else` branch (where the loss decreased), as is done in later chapters.
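For illustration, here is a minimal sketch of the branch logic with the proposed line added. It is not the book's code: it assumes `train_state` is a plain dict whose `'val_loss'` list already holds the current epoch's loss, replaces `args.early_stopping_criteria` with a plain `early_stopping_criteria` argument, and abstracts `torch.save(model.state_dict(), train_state['model_filename'])` into a hypothetical `save_checkpoint` callback so it runs without PyTorch.

```python
def update_train_state(train_state, early_stopping_criteria, save_checkpoint):
    """Early-stopping update with the proposed fix applied (sketch).

    Assumption: save_checkpoint is a hypothetical stand-in for
    torch.save(model.state_dict(), train_state['model_filename']).
    """
    loss_t = train_state['val_loss'][-1]

    # If loss worsened (did not beat the best value seen so far)
    if loss_t >= train_state['early_stopping_best_val']:
        # Update step
        train_state['early_stopping_step'] += 1
    # Loss decreased
    else:
        # Save the best model
        save_checkpoint()
        # Proposed fix: remember the new best validation loss
        train_state['early_stopping_best_val'] = loss_t
        # Reset early stopping step
        train_state['early_stopping_step'] = 0

    # Stop once the loss has failed to improve for enough consecutive epochs
    train_state['stop_early'] = (
        train_state['early_stopping_step'] >= early_stopping_criteria
    )
    return train_state


# Usage: the loss improves twice, then fails to improve three times in a row.
saved_losses = []
state = {'val_loss': [], 'early_stopping_best_val': float('inf'),
         'early_stopping_step': 0, 'stop_early': False}
for loss in [1.0, 0.8, 0.9, 0.85, 0.95]:
    state['val_loss'].append(loss)
    update_train_state(state, early_stopping_criteria=3,
                       save_checkpoint=lambda: saved_losses.append(loss))
    if state['stop_early']:
        break
```

With the fix, training stops after three epochs without improvement on 0.8. Without the added line, `early_stopping_best_val` would stay at its initial value, every loss in this sequence would count as an improvement, and `stop_early` would never become `True`. The inner `if loss_t < ...` check from the original is omitted here because it is always true inside the `else` branch.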