jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

TPU backend gets stuck #19971

Open blazorin opened 9 months ago

blazorin commented 9 months ago

Description

python jax_test.py 

/home/alberta/.local/lib/python3.10/site-packages/jax/_src/xla_bridge.py:146: UserWarning: TPU backend initialization is taking more than 60.0 seconds. Did you run your code on all TPU hosts? See https://jax.readthedocs.io/en/latest/multi_process.html for more information.
  warnings.warn(

I have made multiple attempts to run a simple JAX script, but it keeps getting stuck. I am using a TPU v4-32 on Ubuntu 22.04 with Python 3.10.

jax_test.py

# The following code snippet will be run on all TPU hosts
import jax

# The total number of TPU cores in the Pod
device_count = jax.device_count()

# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the Pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)

System info (python version, jaxlib version, accelerator, etc.)

Calling jax.print_environment_info() gets stuck in the same way.

blazorin commented 9 months ago

After a decent amount of time, the following exception is thrown:

Traceback (most recent call last):
  File "/home/alberta/.local/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 679, in backends
    backend = _init_backend(platform)
  File "/home/alberta/.local/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 761, in _init_backend
    backend = registration.factory()
  File "/home/alberta/.local/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 157, in tpu_client_timer_callback
    client = xla_client.make_tpu_client(_get_tpu_library_path())
  File "/home/alberta/.local/lib/python3.10/site-packages/jaxlib/xla_client.py", line 198, in make_tpu_client
    return make_tfrt_tpu_c_api_client()
  File "/home/alberta/.local/lib/python3.10/site-packages/jaxlib/xla_client.py", line 129, in make_tfrt_tpu_c_api_client
    return _xla.get_c_api_client('tpu', options)
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: TPU initialization failed: Failed to establish SliceBuilder grpc channel to 10.130.0.14:8471.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/alberta/.local/lib/python3.10/site-packages/jax/_src/environment_info.py", line 44, in print_environment_info
    devices_short = str(np.array(xla_bridge.devices())).replace('\n', '')
  File "/home/alberta/.local/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 872, in devices
    return get_backend(backend).devices()
  File "/home/alberta/.local/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 806, in get_backend
    return _get_backend_uncached(platform)
  File "/home/alberta/.local/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 786, in _get_backend_uncached
    bs = backends()
  File "/home/alberta/.local/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 695, in backends
    raise RuntimeError(err_msg)
RuntimeError: Unable to initialize backend 'tpu': UNKNOWN: TPU initialization failed: Failed to establish SliceBuilder grpc channel to 10.130.0.14:8471. (set JAX_PLATFORMS='' to automatically choose an available backend)

DHdroid commented 9 months ago

I've encountered the same issue when using TPU v4-32.

yixiaoer commented 9 months ago

According to the warning, you should run your code on all TPU hosts. A TPU v4-32 has 4 hosts, so you need to run the code simultaneously on all 4 of them. You can refer to using-tpu-pod for help.
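
The host count follows from the slice name. As a rough sketch, assuming the usual TPU v4 layout (the number in `v4-32` counts TensorCores, each chip has 2 TensorCores, and each host VM has 4 chips), the arithmetic looks like this. The constants and the helper name `v4_hosts` are illustrative assumptions, not values read from any JAX API:

```python
# Assumed TPU v4 layout; these constants are not queried from any API.
CORES_PER_CHIP = 2   # TensorCores per v4 chip (assumption)
CHIPS_PER_HOST = 4   # chips attached to each host VM (assumption)


def v4_hosts(accelerator_type: str) -> int:
    """Return the number of host VMs for a name such as 'v4-32' (hypothetical helper)."""
    version, cores = accelerator_type.split("-")
    if version != "v4":
        raise ValueError("this sketch only covers v4 slices")
    chips = int(cores) // CORES_PER_CHIP
    # A slice smaller than one full host still needs one host.
    return max(1, chips // CHIPS_PER_HOST)


print(v4_hosts("v4-32"))  # 32 cores -> 16 chips -> 4 hosts
```

Under that assumption, the script has to be started on all 4 hosts at roughly the same time. Starting it on a single host leaves that process waiting for the others, which is consistent with the 60-second warning.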

onurgu commented 8 months ago

I had the same issue.

However, when I switched to this command:

gcloud compute tpus tpu-vm ssh tpu-vm-name --zone=europe-west4-a   --worker=all   --command="python3 -c 'import jax; jax.distributed.initialize(); jax.process_index() == 0 and print(jax.devices())'"

it worked.

I think the key is to call jax.distributed.initialize() before using any devices.
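
As a minimal sketch of this approach, the launcher below builds the same `gcloud` invocation from Python and passes a worker snippet that calls `jax.distributed.initialize()` before querying devices. The helper name `build_ssh_command` and the constant `WORKER_SNIPPET` are hypothetical; the TPU name and zone are the placeholders from the command above. The command is only built and printed here, not executed:

```python
import shlex

# Snippet that every host runs; initialize() is called before any device query.
WORKER_SNIPPET = (
    "import jax; "
    "jax.distributed.initialize(); "
    "jax.process_index() == 0 and print(jax.devices())"
)


def build_ssh_command(tpu_name: str, zone: str, snippet: str = WORKER_SNIPPET) -> list[str]:
    """Build an argv list that runs `snippet` on all workers (hypothetical helper)."""
    # Quote the snippet so the remote shell passes it to python3 as one argument.
    remote = "python3 -c " + shlex.quote(snippet)
    return [
        "gcloud", "compute", "tpus", "tpu-vm", "ssh", tpu_name,
        f"--zone={zone}",
        "--worker=all",
        f"--command={remote}",
    ]


cmd = build_ssh_command("tpu-vm-name", "europe-west4-a")
print(shlex.join(cmd))
# To actually launch it: import subprocess; subprocess.run(cmd, check=True)
```

Passing the arguments as a list avoids a second layer of local shell quoting; only the remote `python3 -c` argument needs to be quoted.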

defdet commented 8 months ago

I've encountered this warning on v3-32 as well, but after some waiting, the script ran without any errors. It seems that initialization sometimes just takes a while.

yixiaoer commented 5 months ago

To configure TPU devices and run commands across hosts for TPU Pods, you can now use tpux!

s-smits commented 4 months ago

Same problem here.