pytorch / torchtune

A Native-PyTorch Library for LLM Fine-tuning
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

Validation and early stopping during training #883

Open kinggongzilla opened 5 months ago

kinggongzilla commented 5 months ago

Is there a way to evaluate the model performance during training on a validation dataset and only save a new checkpoint if it achieves lower validation loss?

rohan-varma commented 5 months ago

Hi @kinggongzilla, thanks for filing this issue!

Currently we only support early stopping based on the number of steps taken in an epoch, i.e. you can set the max_steps_per_epoch flag in the configuration to stop training early after a fixed number of steps.
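As a concrete illustration, a recipe's YAML config might set the flag like this (the `epochs` key is shown only for context; `max_steps_per_epoch` is the flag discussed above):

```yaml
# Stop each epoch after 100 optimizer steps; leave unset/null to run full epochs.
epochs: 3
max_steps_per_epoch: 100
```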

However, this doesn't satisfy your use case of early stopping or saving a checkpoint only based on validation results.

In-training evaluation and evaluation-based stopping criteria are a large design space we haven't looked into deeply yet; what do you folks think @ebsmothers @RdoubleA? I could see a future in which we allow users to specify a validation dataset or validation split, and incorporate validation metrics into our checkpointer to decide whether to save a checkpoint or not. This is definitely something we could look at enabling in the future if there's sufficient interest.
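For readers who need this today, the idea can be sketched outside of torchtune in plain PyTorch: run a validation pass after each epoch, checkpoint only when the validation loss improves, and stop after a patience window with no improvement. This is a minimal, framework-agnostic sketch; names like `train_step` and the `best_model.pt` path are placeholders, not torchtune APIs.

```python
import math
import torch

def evaluate(model, val_loader, loss_fn, device="cpu"):
    """Mean per-batch loss over a validation dataset."""
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            outputs = model(inputs.to(device))
            total += loss_fn(outputs, targets.to(device)).item()
            count += 1
    model.train()
    return total / max(count, 1)

def train_with_early_stopping(model, train_step, val_loader, loss_fn,
                              max_epochs=10, patience=3):
    """Train, checkpointing only on validation improvement; stop early
    after `patience` epochs without improvement."""
    best_loss = math.inf
    epochs_without_improvement = 0
    for epoch in range(max_epochs):
        train_step(epoch)  # one epoch of training, supplied by the caller
        val_loss = evaluate(model, val_loader, loss_fn)
        if val_loss < best_loss:
            best_loss = val_loss
            epochs_without_improvement = 0
            # Save a checkpoint only when validation loss improves.
            torch.save(model.state_dict(), "best_model.pt")
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break  # early stop: no improvement for `patience` epochs
    return best_loss
```

A real integration into torchtune's recipes would also need to handle distributed evaluation and the existing checkpointer, which is the design work mentioned above.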

kinggongzilla commented 5 months ago

Thanks for the quick reply! Being able to define a validation dataset and do early stopping based on the validation loss would definitely be super helpful.

optimass commented 4 months ago

+1 this would be super useful.

Some-random commented 4 months ago

+1 Would be super useful!

ebsmothers commented 4 months ago

Thanks all for the comments. This feature (along with general validation loops) is fairly high on our wishlist right now. We still need to do a bit of design work to make sure it's not too intrusive into our recipes, but we definitely hear you on the need for this feature. We will keep you posted here!