chrisaballard opened 8 months ago
The state and callback handler are defined in `Trainer.__init__()`.
Create a new `Trainer` instance with an `EarlyStoppingCallback`:
```python
trainer = Trainer(
    args=training_arguments,
    metric=metric,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    column_mapping=column_mapping,
    model_init=model_init(...),
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)
```
Run hyperparameter search:
```python
trainer.hyperparameter_search(
    direction="maximize", hp_space=hp_space, n_trials=n_trials, **kwargs
)
```
The `early_stopping_patience_counter` is stored on the `EarlyStoppingCallback` instance and is not reset to 0 between trials.
The `on_evaluate` method of `EarlyStoppingCallback` compares the current evaluation metric to `state.best_metric`.
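To make the comparison concrete, here is a simplified, self-contained sketch of that check. It is a stand-in for the library code, not the actual transformers source: the key point it illustrates is that the patience counter lives on the callback instance and only resets on an apparent improvement.

```python
class EarlyStoppingSketch:
    """Simplified stand-in for transformers' EarlyStoppingCallback logic."""

    def __init__(self, early_stopping_patience=2, early_stopping_threshold=0.0):
        self.early_stopping_patience = early_stopping_patience
        self.early_stopping_threshold = early_stopping_threshold
        # Instance state: nothing here is tied to a single trial.
        self.early_stopping_patience_counter = 0

    def check_metric_value(self, best_metric, metric_value, greater_is_better=True):
        improved = best_metric is None or (
            (metric_value > best_metric if greater_is_better else metric_value < best_metric)
            and abs(metric_value - best_metric) > self.early_stopping_threshold
        )
        if improved:
            self.early_stopping_patience_counter = 0  # improvement: reset
        else:
            self.early_stopping_patience_counter += 1  # no improvement: increment
        # True here maps to control.should_training_stop = True.
        return self.early_stopping_patience_counter >= self.early_stopping_patience
```

If `best_metric` is stale (carried over from a previous trial), a perfectly normal metric value from a new trial registers as "no improvement" and the counter keeps climbing.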
Because `Trainer.state` is not reset between trials, `Trainer.state.best_metric` retains the most recent value from the previous trial.
Because `EarlyStoppingCallback.early_stopping_patience_counter` is not reset between trials, it carries its value over from previous trials and continues to be incremented, rather than starting each trial at 0.
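The interaction of these two stale pieces of state can be demonstrated with a self-contained toy model (plain Python, no transformers dependency; the class below is a hypothetical stand-in, not the real callback):

```python
class Callback:
    def __init__(self, patience):
        self.patience = patience
        self.counter = 0  # persists for the lifetime of the instance

    def on_evaluate(self, best_metric, metric):
        if best_metric is None or metric > best_metric:
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience  # should_training_stop

cb = Callback(patience=2)
best = None
# Trial 1: metrics plateau, so the counter climbs to the patience limit.
for metric in (0.80, 0.70, 0.60):
    cb.on_evaluate(best, metric)
    best = metric if best is None else max(best, metric)

# Trial 2: neither cb.counter nor best is reset. The very first evaluation
# trips early stopping, even though 0.75 is a reasonable first-epoch score.
stop = cb.on_evaluate(best, 0.75)
print(stop)        # True: trial 2 is terminated immediately
print(cb.counter)  # 3: already past patience before trial 2 even started
```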
If using the `EarlyStoppingCallback` from the transformers package and running the `Trainer.hyperparameter_search` method, `early_stopping_patience_counter` on the callback instance and `Trainer.state` are not reset between subsequent trials. This means that when running the second and subsequent trials, `early_stopping_patience_counter` is already greater than `early_stopping_patience`. This causes `Trainer.control.should_training_stop` to be set to `True` the first time `on_evaluate` is run on the callbacks, and the trial is then terminated.

To resolve this issue, I suggest that when calling `Trainer.train(trial=...)` and passing in a trial, `Trainer.callback_handler` should be re-initialised and `Trainer.state` initialised to `TrainerState()`.
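The suggested reset could look roughly like the following dependency-free sketch. The class names stand in for their transformers counterparts, and the exact integration point inside `train()` is an assumption; the point is that storing callback factories (rather than instances) lets the handler be rebuilt cleanly per trial.

```python
class TrainerState:
    def __init__(self):
        self.best_metric = None

class EarlyStoppingCallback:
    def __init__(self, early_stopping_patience=2):
        self.early_stopping_patience = early_stopping_patience
        self.early_stopping_patience_counter = 0

class Trainer:
    def __init__(self, callback_factories):
        # Store factories rather than instances so callbacks can be rebuilt.
        self.callback_factories = callback_factories
        self._reset_for_trial()

    def _reset_for_trial(self):
        self.state = TrainerState()  # fresh best_metric for this trial
        self.callback_handler = [factory() for factory in self.callback_factories]

    def train(self, trial=None):
        if trial is not None:
            self._reset_for_trial()  # no counters leak across trials
        # ... the actual training loop would run here ...

trainer = Trainer([lambda: EarlyStoppingCallback(early_stopping_patience=2)])
# Simulate state left over from trial 1:
trainer.callback_handler[0].early_stopping_patience_counter = 2
trainer.state.best_metric = 0.80
trainer.train(trial="trial-2")
print(trainer.callback_handler[0].early_stopping_patience_counter)  # 0
print(trainer.state.best_metric)  # None
```

With this behaviour, each trial starts from `TrainerState()` and a zeroed patience counter, so early stopping is judged on that trial's metrics alone.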