stanford-crfm / levanter

Legible, Scalable, Reproducible Foundation Models with Named Tensors and Jax
https://levanter.readthedocs.io/en/latest/
Apache License 2.0

More robust and dynamic multi-slicing TPU training #690

Open Ivan-Zhou opened 2 months ago

Ivan-Zhou commented 2 months ago

The current approach to multi-slice training is not robust or reliable on spot instances.

There has been some discussion inside CRFM and with the GCP team on this topic. I am creating this issue to capture the main ideas and threads.

Main Objectives

Challenges

Possible Ideas

Use Ray for scheduling

Allen Wang from Google proposed using Ray to schedule and run workloads across slices. He put together a quick gist showing how to run both single-slice and multi-slice workloads via Ray (>= 2.10.0). This covers the job scheduling aspect and will work regardless of whether the cluster is provisioned directly on VMs or on GKE.

To mitigate potential race conditions, Allen also added placement groups to pre-reserve existing TPU pod slices (ray_tpu.py) and an example of how they can be used to run tasks (ray_tpu_task.py).
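For readers unfamiliar with placement groups, here is a minimal sketch of the reservation pattern described above. It is not Allen's gist: `slice_bundles`, `run_on_slice`, and `host_task` are hypothetical names, and the defaults assume each TPU host exposes 4 chips under a Ray resource called `"TPU"`. Note that `STRICT_SPREAD` only forces the bundles onto distinct nodes. It does not by itself guarantee that those nodes belong to the same pod slice, so a real setup would need an additional slice-specific constraint.

```python
from typing import Dict, List


def slice_bundles(num_hosts: int, chips_per_host: int = 4) -> List[Dict[str, int]]:
    """Build one placement-group bundle per TPU host (assumed layout)."""
    if num_hosts < 1 or chips_per_host < 1:
        raise ValueError("num_hosts and chips_per_host must be positive")
    return [{"TPU": chips_per_host} for _ in range(num_hosts)]


def run_on_slice(num_hosts: int, chips_per_host: int = 4) -> List[str]:
    """Reserve all hosts up front, then run one task per reserved bundle."""
    # Imported lazily so slice_bundles() is usable without Ray installed.
    import ray
    from ray.util import placement_group, remove_placement_group
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

    ray.init(ignore_reinit_error=True)

    # Reserving every bundle atomically is what avoids the race where two
    # jobs each grab part of the hosts and neither can start.
    pg = placement_group(
        slice_bundles(num_hosts, chips_per_host), strategy="STRICT_SPREAD"
    )
    ray.get(pg.ready())  # blocks until every bundle is reserved

    @ray.remote(resources={"TPU": chips_per_host})
    def host_task(index: int) -> str:
        import socket

        return f"bundle {index} ran on {socket.gethostname()}"

    try:
        futures = [
            host_task.options(
                scheduling_strategy=PlacementGroupSchedulingStrategy(
                    placement_group=pg,
                    placement_group_bundle_index=i,
                )
            ).remote(i)
            for i in range(num_hosts)
        ]
        return ray.get(futures)
    finally:
        # Release the reservation so other jobs can use the hosts.
        remove_placement_group(pg)
```

Calling `run_on_slice(num_hosts=2)` on a cluster with two TPU hosts would return one string per host. On a cluster without enough `"TPU"` resources, `ray.get(pg.ready())` simply waits, which is the behavior one wants when a spot slice has not come back yet.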

David's summary:

Use a host to coordinate work and communicate gradients

@dlwh's idea:

dlwh commented 2 months ago

These two ideas can be fused pretty well I think, fwiw.

Ivan-Zhou commented 2 months ago

I could reproduce Allen's script on a single-slice v4 TPU, but not on multiple slices. It should not be a blocker for now if we are not prioritizing multi-slice training.

Now that I think more about it, I realize this is not the shortest path. I should instead use Marin's existing Ray + TPU framework for launching data preparation jobs as a reference. It seems to be a more applicable guide.

I will try to follow https://github.com/stanford-crfm/marin/tree/main/infra#maintaining-a-ray-cluster and build a PoC for training.

dlwh commented 2 months ago

Sounds like a good plan! The main difference is that Allen's script handles multi-node TPU slices, which the Marin cluster doesn't bother with.

Ivan-Zhou commented 2 months ago

I will run experiments with larger TPU nodes then.