OpenNMT / OpenNMT-py

Open Source Neural Machine Translation and (Large) Language Models in PyTorch
https://opennmt.net/
MIT License

keep_checkpoint option removes best performing model #1946

Open BramVanroy opened 3 years ago

BramVanroy commented 3 years ago

With the keep_checkpoint option we can specify how many checkpoints should be kept. However, checkpoints are saved sequentially and never ordered by quality. That means that if your best-performing model occurs early in training, it may still be deleted.

https://github.com/OpenNMT/OpenNMT-py/blob/073428849c1d10dd4fae7f8fd92699cdc9f230a4/onmt/models/model_saver.py#L79-L83

As an alternative approach, I would suggest that when validation is run before a save step, the validation loss is also passed to the save method. self.checkpoint_queue could then hold tuples of (loss, chkpt_name), and after each append the queue would be re-sorted on loss, so that only the worst-performing models are removed.

Things to consider: ModelSaver would then need to know whether the metric is higher=better or lower=better, and a fallback needs to be in place when no loss is passed.
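The proposal above could be sketched roughly as follows. This is only an illustration, not OpenNMT-py code: the class name `CheckpointQueue` and its interface are hypothetical, and a heap is used instead of re-sorting a list so the worst checkpoint is always evicted first. The higher=better/lower=better question is handled by a sign flip.

```python
import heapq
import os

class CheckpointQueue:
    """Hypothetical sketch: keep the N best checkpoints by validation metric.

    Entries are (signed_score, step, path) tuples. The sign is chosen so
    that heapq (a min-heap) always pops the *worst* checkpoint first.
    """

    def __init__(self, keep_checkpoint, higher_is_better=False):
        self.keep_checkpoint = keep_checkpoint
        # For lower=better metrics (e.g. loss), negate so the largest
        # loss becomes the smallest heap entry and is evicted first.
        self.sign = 1 if higher_is_better else -1
        self.heap = []

    def push(self, score, step, path):
        heapq.heappush(self.heap, (self.sign * score, step, path))
        if len(self.heap) > self.keep_checkpoint:
            _, _, worst_path = heapq.heappop(self.heap)
            if os.path.exists(worst_path):
                os.remove(worst_path)

# Usage: with keep_checkpoint=2, the step-1000 checkpoint (loss 3.1,
# the worst of the three) is the one that gets dropped.
q = CheckpointQueue(keep_checkpoint=2, higher_is_better=False)
q.push(3.1, 1000, "model_step_1000.pt")
q.push(2.4, 2000, "model_step_2000.pt")
q.push(2.9, 3000, "model_step_3000.pt")
```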

francoishernandez commented 3 years ago

Hey Bram, yes, there is a pending PR #1859 and issue #1856 about this, but the first propositions did not convince me, and I have not yet taken the time to look deeper. I like the idea of having a queue that gets updated each time. Some people might still want to keep the N last checkpoints (chronologically), though; that could be handled with a flag, e.g. keep_checkpoints_order with choices=[metricA, metricB, "chronological"]. Feel free to open a PR for such an implementation.

BramVanroy commented 3 years ago

@francoishernandez Ah, my bad - should've looked in the PRs.

In your proposal, you use metricA and metricB. In OpenNMT, can validation ever be done with more than one metric? From build_loss_compute it seems that there is only ever one metric during validation.

I would simply (optionally) pass the validation loss to the model saver here:

https://github.com/OpenNMT/OpenNMT-py/blob/bc95e03875aabf4363161aac70830ecbdb762d91/onmt/trainer.py#L276-L279

If no loss is passed (i.e. no validation is done), the behaviour defaults to "chronological"; otherwise the validation metric is used. If that seems like something that interests you, I can implement it next week.
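The "chronological" fallback described here is essentially what a bounded FIFO already gives you: appending beyond the limit silently drops the oldest entry. A minimal, self-contained model of that default behaviour (file deletion omitted for brevity):

```python
from collections import deque

# Chronological fallback sketch: when no validation loss is available,
# keep only the N most recent checkpoints. A deque with maxlen models
# this: appending a new path beyond maxlen drops the oldest one.
checkpoints = deque(maxlen=2)
for step in (1000, 2000, 3000):
    checkpoints.append(f"model_step_{step}.pt")
# checkpoints now holds only the two most recent: steps 2000 and 3000
```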

francoishernandez commented 3 years ago

This was just a way of keeping it generic for any further extension. A simple boolean flag can indeed do the trick at first.