huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Trainer not keeping best model checkpoint with save_total_limit=1 #15089

Closed: erickrf closed this issue 2 years ago

erickrf commented 2 years ago

Environment info

Who can help

@sgugger

Information

Model I am using (Bert, XLNet ...): BERT

The problem arises when using:

The tasks I am working on are:

To reproduce

Steps to reproduce the behavior:

  1. Create a trainer with save_total_limit=2 and load_best_model_at_end=True
  2. Train the model

After each evaluation, the trainer saves the most recent checkpoint and deletes the previous one to stay within save_total_limit, even if the previous one was better.

That is not what I expected, considering this comment.
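
For reference, a minimal sketch of the setup I mean (the model, datasets, and step values below are placeholders):

from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="output",
    evaluation_strategy="steps",   # evaluate every eval_steps
    eval_steps=500,
    save_strategy="steps",         # save a checkpoint every save_steps
    save_steps=500,
    save_total_limit=2,            # keep at most two checkpoints on disk
    load_best_model_at_end=True,
)
trainer = Trainer(
    model=model,                   # placeholder model
    args=args,
    train_dataset=train_dataset,   # placeholder datasets
    eval_dataset=eval_dataset,
)
trainer.train()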

Expected behavior

I'd expect the best model to be always kept on disk.

sgugger commented 2 years ago

I am unsure what behavior you are seeing, but load_best_model_at_end=True makes sure the best model checkpoint is always kept. That means the absolute best model checkpoint: if the model at step 500 is worse than the one at step 450, and the best checkpoint so far was saved at step 350, the Trainer will indeed delete the checkpoint at step 450 and only keep the checkpoint at step 350 as the best model.

erickrf commented 2 years ago

I see it now; it was actually my fault. I forgot to provide metric_for_best_model, so the trainer was only considering the loss. Sorry for the misunderstanding!
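
For anyone else hitting this, a sketch of the arguments I had been missing (the metric name is only an example and has to match a key returned by compute_metrics):

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="output",
    evaluation_strategy="steps",
    save_strategy="steps",
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_accuracy",  # example metric; without this the loss is used
    greater_is_better=True,                 # set to False for metrics where lower is better
)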

sgugger commented 2 years ago

No problem!

jbmaxwell commented 2 years ago

I'm still confused by this. I'm not able to use your example of save_total_limit=2 and load_best_model_at_end=True, because it fails with:

ValueError: --load_best_model_at_end requires the save and eval strategy to match, but found
- Evaluation strategy: IntervalStrategy.NO
- Save strategy: IntervalStrategy.STEPS

Ideally, I'd like to save the N best checkpoints, but I can't find a way to do that.

I'm on transformers 4.18.0.

UPDATE: Also tried on 4.19.0.dev0

sgugger commented 2 years ago

What don't you understand in the error message? As stipulated, you need to have the same evaluation and save strategy when activating load_best_model_at_end.
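
Concretely, something along these lines (the step values are arbitrary examples; the point is that the evaluation and save strategies, and their intervals, line up):

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="output",
    evaluation_strategy="steps",   # must match save_strategy
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,                # keep in sync with eval_steps
    save_total_limit=2,
    load_best_model_at_end=True,
)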

jbmaxwell commented 2 years ago

Of course, the error message is clear. What I was confused about was how to find a combination of settings that would do what I wanted. Ideally, I wanted to save the N best checkpoints at a given frequency of steps (e.g., every 100 steps). It's helpful, particularly as my model gets closer to converging, to be able to scp checkpoints out to my laptop to try in the context of my actual application. I often do that while the model is still training on my server, so it's nice to have frequent saves where I can quickly find a recent checkpoint with minimal loss.

Anyway, I'm now using the settings suggested here: https://discuss.huggingface.co/t/save-only-best-model-in-trainer/8442/8?u=jbmaxwell, which is fine:

save_total_limit = 2
save_strategy = "no"
load_best_model_at_end=False
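
Written out as TrainingArguments, those settings look roughly like this (output_dir is a placeholder I added; the rest is exactly the values above):

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="output",           # placeholder
    save_total_limit=2,
    save_strategy="no",
    load_best_model_at_end=False,
)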

Though honestly it's still quite counterintuitive, as from the settings alone it looks like this will save only the 2 most recent checkpoints, while apparently it will save the most recent and the best.

gitjoop commented 1 year ago

@sgugger Is it possible to update the documentation of the save_total_limit parameter of Trainer?

What I think would make the documentation better is to describe the interaction between save_total_limit and load_best_model_at_end discussed in this thread, i.e. that the best checkpoint is always kept in addition to the most recent ones.

Maybe this will help others understand the behaviour straight from the docs. Side note: the docs are generally awesome!

sgugger commented 1 year ago

@gitjoop Feel free to open a PR, it would indeed be awesome to have all of this in the doc :-)

y1450 commented 1 year ago

I tried running run_mlm.py from the examples with the MLflow callback, but apparently no checkpoint is logged to MLflow. I used the following config:

save_total_limit = 2
save_strategy = "no"
load_best_model_at_end=False

and set HF_MLFLOW_LOG_ARTIFACTS to 1. I can confirm that it does log checkpoints, but not with the above configuration.
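
For reference, the environment variable in Python form (equivalent to exporting HF_MLFLOW_LOG_ARTIFACTS=1 in the shell before launching run_mlm.py):

import os

# Ask the MLflow callback to also log saved checkpoints/artifacts.
os.environ["HF_MLFLOW_LOG_ARTIFACTS"] = "1"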