LeonardoEmili opened 5 months ago
Can you expand more on how the restart from a loss spike would prevent a new loss spike?
It has happened many times in the past, with different toolkits, that fine-tuning LLMs results in failed trainings due to loss spikes, and the same happened here. Note that this is a common issue for which people haven't found a proper solution yet other than manually restarting the run, as these spikes seem non-deterministic and do not depend on specific data points, training state, or a combination of the two (the picture shows Mistral SFT on 8x A100, where the orange and blue curves overlap perfectly until the orange one experiences a small increase in the training loss, followed by a loss spike).
While restarting the same run from scratch shows that the setup is reproducible, manually resuming the failed run allows fine-tuning the model further (the training in blue was restarted, achieving a smaller loss, as depicted in orange).
Existing implementations, such as MosaicML's Auto Resumption policy, provide exactly this feature and remove the burden on the user of manually exec'ing into the node and restarting the failed run (see docs).
Finally, it seems that people often report this issue for Mistral-based models, see some related issues:
TL;DR: To answer your question, restarting a run that experienced a loss spike does NOT guarantee that new loss spikes won't happen in the future, but it enables faster convergence and more stable training, which is a desirable feature to add when unstable runs are detected.
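To make the discussion concrete, a loss watchdog of this kind can be sketched as a moving-average check that flags a spike when the current loss exceeds a multiple of the recent history. The class and parameter names below are illustrative assumptions, not axolotl's actual implementation:

```python
from collections import deque


class LossWatchdog:
    """Illustrative spike detector: flags a loss that exceeds a
    multiple of the moving average over a recent window."""

    def __init__(self, threshold: float = 2.0, window: int = 20):
        self.threshold = threshold            # spike = loss > threshold * moving average
        self.history = deque(maxlen=window)   # recent loss values

    def check(self, loss: float) -> bool:
        """Return True if `loss` looks like a spike relative to recent history."""
        spike = bool(self.history) and loss > self.threshold * (
            sum(self.history) / len(self.history)
        )
        self.history.append(loss)
        return spike
```

In practice the threshold and window would come from the toolkit's watchdog configuration rather than being hard-coded.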
⚠️ Please check that this feature request hasn't been suggested before.
🔖 Feature description
Automatic resumption when WatchDog detects loss spikes.
It is a simple yet powerful feature that allows automatic resumption of broken trainings when the WatchDog detects an abnormal loss value during training, in turn reducing the overhead of manually restarting failed runs, as this is taken care of by the toolkit.
@winglian I am happy to work on a PR for this if you think it would be useful.
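The resumption logic itself can be sketched as a small retry loop around the training entry point. `LossWatchDogException`, `train_fn`, and the `resume` flag below are stand-ins for illustration, not axolotl's real API:

```python
class LossWatchDogException(Exception):
    """Raised by the watchdog when an abnormal loss value is detected."""


def train_with_auto_resume(train_fn, max_restarts: int = 3):
    """Run `train_fn(resume=...)`, restarting from the latest checkpoint
    whenever a loss spike is detected, up to `max_restarts` times."""
    restarts = 0
    resume = False
    while True:
        try:
            return train_fn(resume=resume)
        except LossWatchDogException:
            restarts += 1
            if restarts > max_restarts:
                raise  # give up after too many spikes
            resume = True  # next attempt reloads weights + optimizer state
```

A restart cap like `max_restarts` keeps a pathologically unstable run from looping forever.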
✔️ Solution
When the watchdog is enabled and the configuration `loss_watchdog_automatic_resume=True` is set, training will raise a `LossWatchDogException`, which will be captured, and a new training routine will be run.

❓ Alternatives
- The `train` function in `axolotls/train.py` could be wrapped so that training is run at the beginning and run again whenever a `LossWatchDogException` is captured (dirty).
- When a `LossWatchDogException` is raised, the toolkit will reload the training state as well as the weights from the latest checkpoint, enabling seamless resume when loss spikes are detected.

📝 Additional Context
A similar solution by MosaicML:
Automatic resumption from node failures and loss spikes. No need to babysit model training.

Acknowledgements