Open demon2036 opened 1 week ago
I'm facing the same problem.
Hi JAX team,
We are also facing this issue on TPU v2, v3, and v4. However, v5p works fine.
The issue should be resolved by now, could you please double check on a newly created TPU VM and let us know if it's still happening?
@gagika Thank you for the update! I've tested on a newly created TPU VM, and the issue is now resolved. Everything is working as expected.
Description
Hi JAX team,
Over the past two days, I've been using GCP's queued-resources to create spot TPU v4-256/v4-64 instances and then running the following Python script.
However, it gets stuck at the jax.distributed.initialize() call. This is strange, because when I created an on-demand TPU v4-64 two weeks ago, jax.distributed.initialize() ran without any issues, and it still works fine on that machine. On the newly created instances, though, the call hangs. I'd therefore like to ask the JAX team for help!
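To tell a genuine hang from a slow startup, one option is to run the blocking call in a daemon thread and report whether it returned within a deadline. The sketch below is not part of the original script: `run_with_timeout` is a hypothetical helper, it uses only the Python standard library, and it has not been verified against jax.distributed.initialize() on a TPU VM (whether that call behaves well off the main thread is an assumption).

```python
import threading
import time


def run_with_timeout(fn, timeout_s):
    """Run fn() in a daemon thread and wait at most timeout_s seconds.

    Returns (finished, result, error). The call is NOT interrupted on
    timeout; the helper only reports that it has not returned yet.
    """
    outcome = {}

    def target():
        try:
            outcome["result"] = fn()
        except BaseException as exc:  # surface any failure to the caller
            outcome["error"] = exc

    worker = threading.Thread(target=target, daemon=True)
    worker.start()
    worker.join(timeout_s)
    finished = not worker.is_alive()
    return finished, outcome.get("result"), outcome.get("error")


if __name__ == "__main__":
    # Stand-in for the blocking call; on a TPU VM one might pass
    # jax.distributed.initialize here instead (assumption, untested).
    finished, result, error = run_with_timeout(lambda: time.sleep(0.2), 1.0)
    print("finished:", finished, "error:", error)
```

If `finished` comes back `False` on every worker of a fresh instance but `True` on the older on-demand one, that would point at the environment rather than the script.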
BUG
pip list
Setup bash
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.34
jaxlib: 0.4.34
numpy: 2.1.2
python: 3.11.9 (main, Apr 19 2024, 16:48:06) [GCC 11.2.0]
jax.devices (128 total, 4 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0) ... TpuDevice(id=126, process_index=31, coords=(2,3,7), core_on_chip=0) TpuDevice(id=127, process_index=31, coords=(3,3,7), core_on_chip=0)]
process_count: 32
platform: uname_result(system='Linux', node='t1v-n-db3292ae-w-17', release='5.19.0-1022-gcp', version='#24~22.04.1-Ubuntu SMP Sun Apr 23 09:51:08 UTC 2023', machine='x86_64')