NVIDIA / NeMo-Aligner

Scalable toolkit for efficient model alignment
Apache License 2.0

Added support for float values for val_check_interval to SFT #202

Closed: trias702 closed this pull request 3 months ago

trias702 commented 3 months ago

What does this PR do?

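Adds support for float values of val_check_interval in SFT. It also adds support for both float and int values of limit_train_batches in SFT and DPO, following their semantics in PyTorch Lightning (PTL): a float is interpreted as a fraction of the epoch, and an int as an absolute number of batches (steps).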
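To illustrate the float/int distinction, here is a minimal Python sketch of how a limit_train_batches value could be turned into a per-epoch batch count under PTL-style semantics. The helper resolve_limit_batches is hypothetical and is not NeMo-Aligner's or PTL's actual code; it assumes the length of the train dataloader is known.

```python
from typing import Union


def resolve_limit_batches(limit: Union[int, float], num_batches: int) -> int:
    """Hypothetical helper: number of batches to consume per epoch (PTL-style).

    - float: a fraction of the dataloader, e.g. 0.5 -> half of the batches
    - int:   an absolute batch count, capped at the dataloader length
    """
    # bool is a subclass of int in Python, so reject it explicitly
    if isinstance(limit, bool):
        raise TypeError("limit must be an int or a float, not a bool")
    if isinstance(limit, int):
        if limit < 0:
            raise ValueError("an int limit must be non-negative")
        return min(limit, num_batches)
    if isinstance(limit, float):
        if not 0.0 <= limit <= 1.0:
            raise ValueError("a float limit must lie in [0.0, 1.0]")
        return int(num_batches * limit)
    raise TypeError(f"unsupported limit type: {type(limit).__name__}")


if __name__ == "__main__":
    print(resolve_limit_batches(0.5, 1000))  # 500: half of the epoch
    print(resolve_limit_batches(100, 1000))  # 100: an absolute count
    print(resolve_limit_batches(1, 1000))    # 1: the int 1 means a single batch
    print(resolve_limit_batches(1.0, 1000))  # 1000: the float 1.0 means the full epoch
```

The last two calls show why the type matters: 1 and 1.0 are numerically equal but select one batch versus the whole epoch.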

This was requested by @Kipok.


Usage

val_check_interval = 0.25   # run validation 4 times per epoch
val_check_interval = 100    # run validation every 100 training steps
limit_train_batches = 0.5   # use only 50% of the training data per epoch
limit_train_batches = 100   # consume only 100 steps of the train dataloader per epoch

All of these options can be used with SFT, DPO, and SPIN.
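A similar sketch for val_check_interval: the hypothetical helper validation_steps lists the training steps within a single epoch after which validation would run. It assumes steps_per_epoch already reflects any limit_train_batches setting and ignores the case where an int interval carries over across epoch boundaries; it is an illustration of the PTL-style semantics, not NeMo-Aligner's actual scheduling logic.

```python
from typing import List, Union


def validation_steps(val_check_interval: Union[int, float], steps_per_epoch: int) -> List[int]:
    """Hypothetical helper: 1-based steps within one epoch after which validation runs."""
    # bool is a subclass of int in Python, so reject it explicitly
    if isinstance(val_check_interval, bool):
        raise TypeError("val_check_interval must be an int or a float, not a bool")
    if isinstance(val_check_interval, float):
        if not 0.0 < val_check_interval <= 1.0:
            raise ValueError("a float val_check_interval must lie in (0.0, 1.0]")
        # Fraction of the epoch: 0.25 -> validate after every quarter of the epoch
        interval = max(1, int(steps_per_epoch * val_check_interval))
    elif isinstance(val_check_interval, int):
        if val_check_interval <= 0:
            raise ValueError("an int val_check_interval must be positive")
        # Absolute count: validate every `val_check_interval` training steps
        interval = val_check_interval
    else:
        raise TypeError(f"unsupported type: {type(val_check_interval).__name__}")
    return list(range(interval, steps_per_epoch + 1, interval))


if __name__ == "__main__":
    print(validation_steps(0.25, 1000))  # [250, 500, 750, 1000]
    print(validation_steps(100, 450))    # [100, 200, 300, 400]
```

Using a float ties the validation cadence to the epoch length, so it stays at "4 times per epoch" even if the dataset size or limit_train_batches changes, whereas an int keeps a fixed number of steps between validations.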
