Open blazorin opened 9 months ago
After a decent amount of time, the following exception is thrown:
Traceback (most recent call last):
  File "/home/alberta/.local/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 679, in backends
    backend = _init_backend(platform)
  File "/home/alberta/.local/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 761, in _init_backend
    backend = registration.factory()
  File "/home/alberta/.local/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 157, in tpu_client_timer_callback
    client = xla_client.make_tpu_client(_get_tpu_library_path())
  File "/home/alberta/.local/lib/python3.10/site-packages/jaxlib/xla_client.py", line 198, in make_tpu_client
    return make_tfrt_tpu_c_api_client()
  File "/home/alberta/.local/lib/python3.10/site-packages/jaxlib/xla_client.py", line 129, in make_tfrt_tpu_c_api_client
    return _xla.get_c_api_client('tpu', options)
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: TPU initialization failed: Failed to establish SliceBuilder grpc channel to 10.130.0.14:8471.
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/alberta/.local/lib/python3.10/site-packages/jax/_src/environment_info.py", line 44, in print_environment_info
    devices_short = str(np.array(xla_bridge.devices())).replace('\n', '')
  File "/home/alberta/.local/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 872, in devices
    return get_backend(backend).devices()
  File "/home/alberta/.local/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 806, in get_backend
    return _get_backend_uncached(platform)
  File "/home/alberta/.local/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 786, in _get_backend_uncached
    bs = backends()
  File "/home/alberta/.local/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 695, in backends
    raise RuntimeError(err_msg)
RuntimeError: Unable to initialize backend 'tpu': UNKNOWN: TPU initialization failed: Failed to establish SliceBuilder grpc channel to 10.130.0.14:8471. (set JAX_PLATFORMS='' to automatically choose an available backend)
I've encountered the same issue when using TPU v4-32.
According to the warning, you should run your code on all TPU hosts. As TPU v4-32 has 4 hosts, you should run the code simultaneously on all 4 hosts. You can refer to using-tpu-pod for help.
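For anyone wondering where the "4 hosts" figure comes from, here is a rough sketch of the arithmetic. It assumes that the number in a v4 accelerator type counts TensorCores, that a v4 chip has 2 TensorCores, and that each v4 host carries 4 chips. None of these figures are stated in this thread, so please check them against the Cloud TPU documentation for your generation. The function name is made up for illustration.

```python
def v4_host_count(accelerator_type, cores_per_chip=2, chips_per_host=4):
    """Estimate how many hosts a TPU v4 slice has (hypothetical helper).

    Assumption: the suffix of a v4 type such as "v4-32" counts TensorCores,
    with 2 cores per chip and 4 chips per host.
    """
    generation, cores = accelerator_type.split("-")
    if generation != "v4":
        raise ValueError("this sketch only covers TPU v4 accelerator types")
    chips = int(cores) // cores_per_chip
    # A slice always has at least one host, even when it has few chips.
    return max(1, chips // chips_per_host)


hosts = v4_host_count("v4-32")
print(f"v4-32 -> {hosts} hosts; the script has to run on every one of them")
```

Under those assumptions a v4-8 is a single host, which would explain why the same script runs there without any multi-host setup.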
I had the same issue.
However, when I switched to this command:
gcloud compute tpus tpu-vm ssh tpu-vm-name --zone=europe-west4-a --worker=all --command="python3 -c 'import jax; jax.distributed.initialize(); jax.process_index() == 0 and print(jax.devices())'"
it worked.
I think the key is to call jax.distributed.initialize() on every host.
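If you would rather drive this from a script than type the gcloud line by hand, something like the following might work. It is only a sketch: it rebuilds the same command shown above as an argument list, and the helper names are hypothetical. The TPU name and zone are the placeholders from the command above, so substitute your own.

```python
import shlex
import subprocess


def build_all_workers_command(tpu_name, zone, remote_command):
    """Return the gcloud argv that runs remote_command on every TPU host."""
    return [
        "gcloud", "compute", "tpus", "tpu-vm", "ssh", tpu_name,
        f"--zone={zone}",
        "--worker=all",  # run on all hosts of the slice, not only worker 0
        f"--command={remote_command}",
    ]


def run_on_all_workers(tpu_name, zone, remote_command):
    """Launch the command; needs an installed and authenticated gcloud CLI."""
    argv = build_all_workers_command(tpu_name, zone, remote_command)
    return subprocess.run(argv, check=True)


# The Python one-liner from the command above, executed on each host.
JAX_CHECK = (
    "import jax; jax.distributed.initialize(); "
    "jax.process_index() == 0 and print(jax.devices())"
)

argv = build_all_workers_command(
    "tpu-vm-name", "europe-west4-a", f"python3 -c {shlex.quote(JAX_CHECK)}"
)
# Print the command for inspection instead of running it.
print(shlex.join(argv))
```

Passing an argument list to subprocess.run avoids a second layer of shell quoting on the local machine; only the remote python3 -c payload needs quoting.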
I've encountered this warning on v3-32 as well, but after some waiting, the script ran without any errors. It seems that initialization sometimes just takes a while.
To configure TPU devices and run commands across hosts for TPU Pods, you can now use tpux!
Same problem here.
Description
I have made multiple attempts to run a simple JAX script, but it keeps getting stuck. I am using TPU v4-32 on Ubuntu 22.04 with Python 3.10.
jax_test.py
System info (python version, jaxlib version, accelerator, etc.)
The same behaviour occurs with jax.print_environment_info().