google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

CTRL+C Broken When Running distributed.initialize() on one TPU Host by Accident #22436

Open s-smits opened 1 month ago

s-smits commented 1 month ago

Issue

Encountered a hang while running a JAX-based LLM training script (EasyLM) on a TPU v4-32 pod. I SSH'd into worker 0 and ran the script there directly, instead of launching it on all workers with --worker all --command "...". CTRL+C should be able to kill the process, but in this case it did not.

Command

python -m EasyLM.models.llama.llama_train \
    --mesh_dim='-1,4,1' \
    --dtype='fp16' \
    --total_steps=250000 \
    --log_freq=50 \
    --save_model_freq=0 \
    --save_milestone_freq=2500 \
    --load_llama_config='7b' \
    --update_llama_config='' \
    --load_dataset_state='' \
    --load_checkpoint='' \
    --optimizer.type='adamw' \
    --optimizer.adamw_optimizer.weight_decay=0.1 \
    --optimizer.adamw_optimizer.lr=3e-4 \
    --optimizer.adamw_optimizer.end_lr=3e-5 \
    --optimizer.adamw_optimizer.lr_warmup_steps=2000 \
    --optimizer.adamw_optimizer.lr_decay_steps=250000 \
    --train_dataset.type='json' \
    --train_dataset.text_processor.fields='text' \
    --train_dataset.json_dataset.path='togethercomputer/RedPajama-Data-1T' \
    --train_dataset.json_dataset.seq_length=2048 \
    --train_dataset.json_dataset.batch_size=2048 \
    --train_dataset.json_dataset.tokenizer_processes=16 \
    --checkpointer.save_optimizer_state=True \
    --logger.online=True \
    --logger.prefix='EasyLM' \
    --logger.project="open_llama_7b" \
    --logger.output_dir="." \
    --logger.wandb_dir="$HOME/experiment_output/open_llama_7b"

Error Message

/home/air/.local/lib/python3.10/site-packages/jax/_src/xla_bridge.py:152: UserWarning: TPU backend initialization is taking more than 60.0 seconds. Did you run your code on all TPU hosts? See https://jax.readthedocs.io/en/latest/multi_process.html for more information.
  warnings.warn(

After this warning, the script hung indefinitely and CTRL+C did not kill it.

System info (python version, jaxlib version, accelerator, etc.)

Environment

young-geng commented 1 month ago

Author of EasyLM here. This is expected: you are running the training script on only one of the 4 hosts in a v4-32 pod. JAX uses an SPMD (single program, multiple data) execution model, which means you need to run the same script (with the same command) on every host in a TPU pod. You can SSH into each host manually and run the command there, or use gcloud's ssh command with --worker=all to run it on all hosts at once. If you are looking for something more user-friendly, you can also check out my recent TPU pod command package, which helps you launch jobs on TPU pods.
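For the --worker=all route, here is a minimal Python sketch of assembling and optionally running the gcloud command. The helper name launch_on_all_workers, the TPU name, and the zone are hypothetical placeholders. It assumes the gcloud CLI is installed and that the code (e.g. the EasyLM checkout) already exists at the same path on every worker. Using shlex.join quotes the remote command once, which helps with long commands like the one above that contain quotes.

```python
import shlex
import subprocess


def launch_on_all_workers(tpu_name, zone, remote_argv, dry_run=True):
    """Build (and optionally run) a gcloud command that runs remote_argv on every TPU VM worker.

    Hypothetical helper: tpu_name and zone are placeholders for your own pod.
    """
    # Quote the remote command once so gcloud passes it to each worker intact.
    remote_cmd = shlex.join(remote_argv)
    argv = [
        "gcloud", "compute", "tpus", "tpu-vm", "ssh", tpu_name,
        f"--zone={zone}",
        "--worker=all",  # run on every host, as SPMD execution requires
        f"--command={remote_cmd}",
    ]
    if not dry_run:
        # Actually launch; raises CalledProcessError if gcloud exits non-zero.
        subprocess.run(argv, check=True)
    return argv
```

With dry_run=True (the default) the function only returns the argument list for inspection. Passing dry_run=False runs it.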

skye commented 1 month ago

@s-smits what do you think would be the best behavior here? We could eventually time out completely and exit with an error instead of hanging, but the timeout would have to be pretty long to make sure we don't accidentally quit too early on large deployments. I'm curious if you have more ideas!
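Here is a minimal sketch of how the timeout idea could look from the user side. JAX does not provide this. The blocking call runs in a daemon thread, and the main thread waits in short slices, raising an error once a deadline passes. Waiting in slices also leaves the main thread free to receive KeyboardInterrupt. The helper name call_with_timeout and the poll interval are hypothetical. Whether jax.distributed.initialize is safe to call off the main thread is an assumption this thread does not confirm.

```python
import threading
import time


def call_with_timeout(fn, timeout_s, poll_s=0.5):
    """Run fn() in a daemon thread; raise TimeoutError if it has not returned within timeout_s.

    Hypothetical helper. The main thread waits in short join() slices, so it is not
    itself stuck inside a native call and can still receive KeyboardInterrupt.
    """
    outcome = {}

    def target():
        try:
            outcome["value"] = fn()
        except BaseException as exc:  # re-raised in the caller's thread below
            outcome["error"] = exc

    worker = threading.Thread(target=target, daemon=True)
    worker.start()
    deadline = time.monotonic() + timeout_s
    while worker.is_alive():
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            raise TimeoutError(f"call did not finish within {timeout_s} s")
        worker.join(min(poll_s, remaining))
    if "error" in outcome:
        raise outcome["error"]
    return outcome["value"]
```

A caller might write call_with_timeout(jax.distributed.initialize, timeout_s=1800) and call os._exit(1) on TimeoutError or KeyboardInterrupt. A normal interpreter shutdown could still stall on atexit handlers or native runtime threads. The 1800-second value is arbitrary; as noted above, a safe timeout for large deployments would have to be long.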

s-smits commented 1 month ago

Yes, that does sound tricky. In my opinion a timeout isn't needed; being able to cancel the TPU processes started by distributed.initialize() with CTRL+C would solve the problem for me. My guess is that the process tree is deeply nested, with multiple layers of TPU processes, so that even repeated CTRL+C presses fail to terminate all of it. Being able to do so would be desirable, though. If this is a unique edge case and you've never seen such a request before, I don't mind closing the issue without a solution.

skye commented 1 month ago

Ah, great point about CTRL+C not working! I've also run into this, although I haven't investigated it yet. I think it has something to do with calling into the C++ TPU runtime.
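Here is a plausible mechanism that fits the C++ runtime guess above. CPython's own SIGINT handler only sets a flag. KeyboardInterrupt is raised later, when the main thread next runs Python bytecode or when native code calls PyErr_CheckSignals. If the main thread is blocked inside a native call that neither returns nor checks for signals, CTRL+C appears to do nothing.

One user-side workaround, sketched below, is to restore the operating system's default SIGINT action before initialization, so the kernel terminates the process directly. This is a workaround, not a fix inside JAX. The helper name make_ctrl_c_fatal is hypothetical.

```python
import signal


def make_ctrl_c_fatal():
    """Restore the OS default action for SIGINT, which terminates the process.

    Hypothetical helper. Call it from the main thread before entering code that may
    block in native code (e.g. before jax.distributed.initialize()).
    Returns the previous handler so it can be restored later.
    """
    previous = signal.getsignal(signal.SIGINT)
    # With SIG_DFL the kernel ends the process on SIGINT; no Python code has to
    # run first, so a main thread stuck in a native call cannot delay it.
    signal.signal(signal.SIGINT, signal.SIG_DFL)
    return previous
```

The trade-off is that CTRL+C then kills the process without raising KeyboardInterrupt, so try/finally blocks and save-on-interrupt logic will not run. signal.signal must be called from the main thread, and a native library that installs its own SIGINT handler afterwards would override this setting. Whether it cleans up every process in the tree described above is untested here.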

Let's keep this issue open to track handling CTRL+C correctly in this situation. I took the liberty of editing your initial issue a bit to include this, so people don't have to read the whole thread (feel free to edit my edits more). Thanks for this feedback!