huggingface / setfit

Efficient few-shot learning with Sentence Transformers
https://hf.co/docs/setfit
Apache License 2.0

EarlyStoppingCallback early_stopping_patience_counter and Trainer.state not reset between hyperparameter search trials #502

Status: Open. Opened by chrisaballard 8 months ago.

chrisaballard commented 8 months ago

When the EarlyStoppingCallback from the transformers package is used with the Trainer.hyperparameter_search method, neither the callback instance's early_stopping_patience_counter nor Trainer.state is reset between trials.

As a result, from the second trial onwards, early_stopping_patience_counter is already at or above early_stopping_patience when the trial starts (for example, when the previous trial itself ended by early stopping). Trainer.control.should_training_stop is then set to True the first time the callbacks' on_evaluate runs, and the trial is terminated immediately.
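The interaction can be reproduced without setfit or transformers. The following is a minimal, hypothetical simulation: `State` stands in for `TrainerState`, `PatienceCallback` paraphrases the counter logic of `EarlyStoppingCallback` (ignoring its `early_stopping_threshold`), and `run_trial` is an invented helper that feeds a list of eval losses through one trial, assuming lower loss is better. Both objects are shared across two trials, as in the bug report.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class State:
    """Hypothetical stand-in for transformers.TrainerState."""
    best_metric: Optional[float] = None


class PatienceCallback:
    """Simplified paraphrase of EarlyStoppingCallback's counter logic (lower loss is better)."""

    def __init__(self, patience: int) -> None:
        self.patience = patience
        self.counter = 0  # plays the role of early_stopping_patience_counter

    def on_evaluate(self, state: State, loss: float) -> bool:
        """Reset on improvement, increment otherwise; return True if training should stop."""
        if state.best_metric is None or loss < state.best_metric:
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience


def run_trial(state: State, callback: PatienceCallback, losses: List[float]) -> int:
    """Hypothetical helper: return how many evaluations ran before the trial stopped."""
    for step, loss in enumerate(losses, start=1):
        should_stop = callback.on_evaluate(state, loss)
        # The best metric is recorded after the callback has compared against it.
        if state.best_metric is None or loss < state.best_metric:
            state.best_metric = loss
        if should_stop:
            return step
    return len(losses)


# Shared across trials, as in the bug report: neither object is reset.
shared_state = State()
shared_callback = PatienceCallback(patience=2)

trial_1 = run_trial(shared_state, shared_callback, [0.50, 0.40, 0.45, 0.46, 0.30])
trial_2 = run_trial(shared_state, shared_callback, [0.60, 0.55, 0.35])
print(trial_1, trial_2)  # 4 1
```

Trial 1 stops at its fourth evaluation with the counter at 2. Trial 2 then stops at its very first evaluation: 0.60 does not beat the stale best_metric of 0.40, so the carried-over counter goes to 3. With a fresh `State()` and `PatienceCallback(patience=2)`, the same trial 2 losses run all three evaluations.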

To resolve this issue, I suggest that when Trainer.train(trial=...) is called with a trial, Trainer.callback_handler should be re-initialised and Trainer.state reset to a fresh TrainerState().
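A sketch of the proposed fix, under the same simplified, hypothetical model: `MiniTrainer` is an invented class, not setfit's `Trainer`, and `trial` is only a label here. When `train` receives a trial it creates a fresh `State()` (the analogue of `TrainerState()`) and rebuilds its callbacks. The callbacks are deep-copied from the originally supplied instances because the counter lives on the callback object: wrapping the same `EarlyStoppingCallback` instance in a new handler would keep its counter. Whether setfit would need per-trial copies in practice is an assumption of this sketch.

```python
import copy
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class State:
    """Hypothetical stand-in for transformers.TrainerState."""
    best_metric: Optional[float] = None


class PatienceCallback:
    """Simplified EarlyStoppingCallback counter logic (lower loss is better)."""

    def __init__(self, patience: int) -> None:
        self.patience = patience
        self.counter = 0

    def on_evaluate(self, state: State, loss: float) -> bool:
        if state.best_metric is None or loss < state.best_metric:
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience


class MiniTrainer:
    """Hypothetical trainer that resets per-trial state, as proposed in the issue."""

    def __init__(self, callbacks: List[PatienceCallback]) -> None:
        self._callback_templates = callbacks  # never run directly, so never mutated
        self._reset()

    def _reset(self) -> None:
        self.state = State()  # analogue of Trainer.state = TrainerState()
        # Fresh copies of the untouched templates, so each trial's counter starts at 0.
        self.callbacks = [copy.deepcopy(cb) for cb in self._callback_templates]

    def train(self, losses: List[float], trial: Optional[str] = None) -> int:
        """Return the number of evaluations run before stopping."""
        if trial is not None:
            self._reset()
        for step, loss in enumerate(losses, start=1):
            # Build the full list so every callback sees every evaluation.
            stops = [cb.on_evaluate(self.state, loss) for cb in self.callbacks]
            if self.state.best_metric is None or loss < self.state.best_metric:
                self.state.best_metric = loss
            if any(stops):
                return step
        return len(losses)


trainer = MiniTrainer([PatienceCallback(patience=2)])
trial_1 = trainer.train([0.50, 0.40, 0.45, 0.46, 0.30], trial="trial-1")
trial_2 = trainer.train([0.60, 0.55, 0.35], trial="trial-2")
print(trial_1, trial_2)  # 4 3
```

With the reset in place, trial 2 is judged only against its own evaluations and runs to completion.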

chrisaballard commented 8 months ago

Trainer.state and Trainer.callback_handler are created only once, in Trainer.__init__():

https://github.com/huggingface/setfit/blob/73d8646e8c49d7926a5bac0ff0eb9f0d3bf96f75/src/setfit/trainer.py#L235C9-L238C94

Create a new Trainer instance with an EarlyStoppingCallback:

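```python
from setfit import Trainer
from transformers import EarlyStoppingCallback

# training_arguments, metric, train_dataset, eval_dataset and column_mapping are
# defined elsewhere. model_init(...) must evaluate to a callable that builds a
# fresh SetFitModel for each trial.
trainer = Trainer(
    args=training_arguments,
    metric=metric,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    column_mapping=column_mapping,
    model_init=model_init(...),
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)
```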

Run hyperparameter search:

```python
trainer.hyperparameter_search(
    direction="maximize", hp_space=hp_space, n_trials=n_trials, **kwargs
)
```

early_stopping_patience_counter is an attribute of the EarlyStoppingCallback instance, and the same instance is reused for every trial, so the counter is never reset to 0 between trials.

The on_evaluate method of EarlyStoppingCallback compares the current eval loss to state.best_metric:

https://github.com/huggingface/transformers/blob/f4364a6ff16e33186cb40f1d3fafd3792556d1b8/src/transformers/trainer_callback.py#L566
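The comparison can be paraphrased as a pure function. This is a sketch assuming `check_metric_value` works as follows: reset the counter when `metric_value` beats `state.best_metric` by more than `early_stopping_threshold`, otherwise increment it, and stop once the counter reaches `early_stopping_patience`. The function names are hypothetical, and lower is assumed to be better, as for a loss. It separates the two stale values discussed in this issue.

```python
import operator
from typing import Optional


def update_patience_counter(
    counter: int,
    best_metric: Optional[float],
    metric_value: float,
    greater_is_better: bool = False,
    threshold: float = 0.0,
) -> int:
    """Hypothetical paraphrase of EarlyStoppingCallback.check_metric_value."""
    better = operator.gt if greater_is_better else operator.lt
    if best_metric is None or (
        better(metric_value, best_metric) and abs(metric_value - best_metric) > threshold
    ):
        return 0  # improvement: reset the counter
    return counter + 1  # no (sufficient) improvement


def should_stop(counter: int, patience: int) -> bool:
    # Training stops once the counter reaches the patience.
    return counter >= patience


PATIENCE = 2

# Start of trial 2 with both values carried over from trial 1.
stale = update_patience_counter(counter=2, best_metric=0.40, metric_value=0.60)
# Counter reset, but best_metric still carried over: one non-improving eval.
stale_best_only = update_patience_counter(counter=0, best_metric=0.40, metric_value=0.60)
# Both reset, as a fresh TrainerState() and callback would give.
fresh = update_patience_counter(counter=0, best_metric=None, metric_value=0.60)

print(stale, should_stop(stale, PATIENCE))  # 3 True
print(stale_best_only, should_stop(stale_best_only, PATIENCE))  # 1 False
print(fresh, should_stop(fresh, PATIENCE))  # 0 False
```

The stale counter is what ends the trial at once; the stale best_metric alone only makes each early evaluation of the new trial count against it.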

Because Trainer.state is not reset between trials, Trainer.state.best_metric still holds the value carried over from the previous trial(s).

Because EarlyStoppingCallback.early_stopping_patience_counter is not reset between trials either, it accumulates across all runs: every evaluation that does not improve on state.best_metric increments it, regardless of which trial the evaluation belongs to.