axolotl-ai-cloud / axolotl

Automatic resumption with WatchDog #1310

Open LeonardoEmili opened 5 months ago

LeonardoEmili commented 5 months ago

⚠️ Please check that this feature request hasn't been suggested before.

🔖 Feature description

Automatic resumption when WatchDog detects loss spikes.

It is a simple yet powerful feature: when the WatchDog detects an abnormal loss value during training, the broken run is resumed automatically, removing the overhead of manually restarting failed runs since this is taken care of by the toolkit.
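
For context, here is a minimal sketch of the kind of check the loss WatchDog performs. This is illustrative only, not axolotl's actual callback; it just mirrors the existing loss_watchdog_threshold / loss_watchdog_patience options on top of a standard transformers callback.

```python
from transformers import TrainerCallback


class LossWatchdogSketch(TrainerCallback):
    """Illustrative only: flag the run as broken once the training loss
    exceeds `threshold` for `patience` consecutive logging steps."""

    def __init__(self, threshold: float, patience: int = 3):
        self.threshold = threshold
        self.patience = patience
        self.violations = 0

    def on_log(self, args, state, control, logs=None, **kwargs):
        loss = (logs or {}).get("loss")
        if loss is None:
            return control
        if loss > self.threshold:
            self.violations += 1
            if self.violations >= self.patience:
                # Today the run simply stops here and has to be restarted by
                # hand; this request is to resume it automatically instead.
                control.should_training_stop = True
        else:
            self.violations = 0
        return control
```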

@winglian I am happy to work on a PR for this if you think it would be useful.

✔️ Solution

When the watchdog is enabled and the configuration loss_watchdog_automatic_resume=True is set, training will raise a LossWatchDogException; the exception is captured and a new training routine is run again.
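
A rough sketch of the proposed behaviour, wrapping the training entry point. LossWatchDogException and loss_watchdog_automatic_resume are the names proposed above (not yet in axolotl); train_fn and cfg stand in for axolotl's train() function and the parsed config.

```python
from transformers.trainer_utils import get_last_checkpoint


class LossWatchDogException(RuntimeError):
    """Exception name proposed in this issue; raised when the watchdog trips."""


def train_with_auto_resume(train_fn, cfg, max_restarts: int = 3):
    """Wrap a training entry point (e.g. axolotl's train()) so a watchdog
    failure restarts the run instead of aborting it."""
    restarts = 0
    while True:
        try:
            return train_fn(cfg)
        except LossWatchDogException:
            if not getattr(cfg, "loss_watchdog_automatic_resume", False):
                raise
            if restarts >= max_restarts:
                raise  # avoid an endless restart loop on persistent spikes
            restarts += 1
            # Resume from the most recent checkpoint-* directory so that
            # weights and optimizer/scheduler state are restored, not
            # reinitialised from scratch.
            cfg.resume_from_checkpoint = get_last_checkpoint(cfg.output_dir)
```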

❓ Alternatives

  1. The existing train function in axolotl's train.py could be wrapped so that training runs once at the start and is re-run whenever a LossWatchDogException is captured (dirty)
  2. The Trainer is aware of the WatchDog auto-resume logic: when a LossWatchDogException is raised, it reloads the training state as well as the weights from the latest checkpoint, enabling seamless resumption when loss spikes are detected (see the sketch after this list)
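
A minimal sketch of alternative 2, assuming the watchdog raises instead of only stopping the run. AutoResumeTrainer and LossWatchDogException are hypothetical names; get_last_checkpoint is the transformers helper that finds the newest checkpoint-* directory in output_dir.

```python
from transformers import Trainer
from transformers.trainer_utils import get_last_checkpoint


class LossWatchDogException(RuntimeError):
    """Exception name proposed in this issue; raised by the watchdog."""


class AutoResumeTrainer(Trainer):
    """Hypothetical trainer that retries from the latest checkpoint whenever
    the watchdog raises, instead of aborting the run."""

    def train(self, resume_from_checkpoint=None, max_restarts: int = 3, **kwargs):
        restarts = 0
        while True:
            try:
                return super().train(
                    resume_from_checkpoint=resume_from_checkpoint, **kwargs
                )
            except LossWatchDogException:
                if restarts >= max_restarts:
                    raise  # persistent failure, stop retrying
                restarts += 1
                # Reload model weights plus optimizer/scheduler state from the
                # most recent checkpoint-* directory and continue training.
                resume_from_checkpoint = get_last_checkpoint(self.args.output_dir)
                if resume_from_checkpoint is None:
                    raise  # no checkpoint written yet, surface the failure
```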

📝 Additional Context

A similar solution by MosaicML: "Automatic resumption from node failures and loss spikes. No need to babysit model training."

Acknowledgements

winglian commented 5 months ago

Can you expand more on how the restart from a loss spike would prevent a new loss spike?

LeonardoEmili commented 5 months ago

It has happened many times in the past, with different toolkits, that fine-tuning LLMs fails because of loss spikes, and the same happened here. Note that this is a common issue for which people have not found a proper solution other than manually restarting the run, as these spikes appear to be non-deterministic and do not depend on specific data points, the training state, or a combination of the two (the picture shows a Mistral SFT on 8x A100 where the orange and blue curves overlap perfectly until the orange one shows a small increase in the training loss, followed by a loss spike).

[figure: loss_spike]

While restarting the same run from scratch shows that the setup is reproducible, manually restarting the failed run allows fine-tuning the model further (the training in blue was restarted and achieved a smaller loss, as depicted in orange).

[figure: loss_restart]

Existing implementations, such as MosaicML's Auto Resumption policy, provide exactly this feature and remove the burden on the user of manually exec'ing into the node and restarting the failed run (see docs).

Finally, it seems that people often report this issue for Mistral-based models; see some related issues:

TL;DR: To answer your question, restarting a run that experienced a loss spike does NOT guarantee that new loss spikes won't happen in the future, but in practice it enables faster convergence and more stable training, which is a desirable feature to have when unstable runs are detected.