mosaicml / llm-foundry

LLM training code for Databricks foundation models
https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm
Apache License 2.0

Managing Timeout on Training Errors and Simultaneous Restart of All Nodes in LLM Foundry #1272

Closed germanjke closed 3 weeks ago

germanjke commented 3 weeks ago

Problem Description:

In our LLM Foundry environment running on Kubernetes, when a hardware failure or any other error occurs on one node during training, the default behavior is to restart only that node. We need all nodes to restart simultaneously instead, to minimize downtime and preserve the integrity of the training run.

Proposed Solution:

To address this, we would like to configure Kubernetes so that all nodes restart together whenever an error occurs during training. This can be done by setting restartPolicy: Always in the Kubernetes configuration. However, it is crucial that errors are propagated to all nodes so that the simultaneous restart is actually triggered.

In conjunction with the Kubernetes configuration, introducing a new option such as dist_failure_timeout in the LLM Foundry configuration would make it possible to manage timeouts for errors during training. With a relatively low timeout value, such as 10 seconds, any node that encounters an error would cause all nodes to restart together once the timeout expires, minimizing downtime.

Example Usage (Kubernetes Configuration):

spec:
  restartPolicy: Always
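
For illustration, a fuller (though still minimal) sketch of a Pod manifest with this restart policy is shown below. The pod name, image, command, and GPU count are placeholders rather than part of any existing LLM Foundry deployment, the distributed environment variables that the composer launcher needs for multi-node runs are omitted, and note that restartPolicy: Always is only accepted for standalone Pods (a Kubernetes Job allows only Never or OnFailure).

apiVersion: v1
kind: Pod
metadata:
  name: llm-foundry-trainer-0        # hypothetical name; one such pod per node
spec:
  restartPolicy: Always              # restart the container whenever it exits
  containers:
    - name: trainer
      image: my-registry/llm-foundry:latest   # placeholder image
      command: ["composer", "scripts/train/train.py", "scripts/train/yamls/pretrain/mpt-7b.yaml"]
      resources:
        limits:
          nvidia.com/gpu: 8          # placeholder GPU count per node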

Example Usage (LLM Foundry Configuration):

dist_failure_timeout: 10

This combined approach would handle training errors by restarting all nodes of an LLM Foundry job on Kubernetes simultaneously, keeping downtime to a minimum.

dakinggg commented 3 weeks ago

I don't really understand this issue. LLM Foundry has nothing to do with Kubernetes. When a run crashes, all ranks will crash (in some edge cases some ranks may hang), but either way the whole job needs to be taken down and restarted. LLM Foundry has a config parameter dist_timeout that will get passed along to PyTorch for timing out distributed operations. Please feel free to open a new issue with additional information if needed.
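
For reference, a minimal sketch of how the existing dist_timeout parameter might be set, assuming it sits at the top level of a training YAML; the value shown is illustrative and, as noted above, is forwarded to PyTorch for timing out distributed operations:

# excerpt from a training config; all other keys omitted
dist_timeout: 600  # seconds to wait on distributed operations before the job is torn down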