jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

jax.distributed.initialize() crash #24399

Open demon2036 opened 1 week ago

demon2036 commented 1 week ago

Description

Hi JAX team,

In the past two days, I've been using GCP queued resources to create spot TPU v4-256/v4-64 slices and then running the following Python script:

import jax
jax.distributed.initialize()
print(1)

However, the script gets stuck at jax.distributed.initialize() and then fails with the traceback below. This is strange because on an on-demand TPU v4-64 I created two weeks ago, jax.distributed.initialize() ran without any issues, and it still works fine on that machine. The problem only appears on the newly created instances, so I'd like to ask the JAX team for help!

BUG


    jax.distributed.initialize()
  File "/root/miniconda3/lib/python3.11/site-packages/jax/_src/distributed.py", line 231, in initialize
    global_state.initialize(coordinator_address, num_processes, process_id,
  File "/root/miniconda3/lib/python3.11/site-packages/jax/_src/distributed.py", line 55, in initialize
    clusters.ClusterEnv.auto_detect_unset_distributed_params(
  File "/root/miniconda3/lib/python3.11/site-packages/jax/_src/clusters/cluster.py", line 82, in auto_detect_unset_distributed_params
    process_id = env.get_process_id()
                 ^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/jax/_src/clusters/cloud_tpu_cluster.py", line 144, in get_process_id
    slice_id = cls._get_slice_id()
               ^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/jax/_src/clusters/cloud_tpu_cluster.py", line 159, in _get_slice_id
    if has_megascale_address():
       ^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/jax/_src/clusters/cloud_tpu_cluster.py", line 74, in has_megascale_address
    return get_tpu_env_value('MEGASCALE_COORDINATOR_ADDRESS') is not None
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/jax/_src/clusters/cloud_tpu_cluster.py", line 71, in get_tpu_env_value
    return value if value is not None else get_tpu_env_value_from_metadata(key)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/jax/_src/clusters/cloud_tpu_cluster.py", line 59, in get_tpu_env_value_from_metadata
    tpu_env_data = get_metadata('tpu-env')[0]
                   ^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/jax/_src/clusters/cloud_tpu_cluster.py", line 45, in get_metadata
    api_resp = requests.get(
               ^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/requests/api.py", line 73, in get
    return request("get", url, params=params, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/requests/api.py", line 59, in request
    return session.request(method=method, url=url, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/requests/sessions.py", line 589, in request
    resp = self.send(prep, **send_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/requests/sessions.py", line 703, in send
    r = adapter.send(request, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/requests/adapters.py", line 700, in send
    raise ConnectionError(e, request=request)
requests.exceptions.ConnectionError: HTTPConnectionPool(host='metadata.google.internal', port=80): Max retries exceeded with url: /computeMetadata/v1/instance/attributes/tpu-env (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7f9ac8f53c50>: Failed to establish a new connection: [Errno 111] Connection refused'))
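
The last JAX frames in the traceback show the lookup order: get_tpu_env_value returns the environment variable if it is set and otherwise calls get_tpu_env_value_from_metadata, which fetches the tpu-env metadata attribute. So when MEGASCALE_COORDINATOR_ADDRESS is not set in the environment, has_megascale_address() ends up making an HTTP request to metadata.google.internal. The sketch below mirrors that order. It is not JAX's code: the "KEY: 'value'" line format assumed for tpu-env and the injected fetch_tpu_env callable are illustrative assumptions.

```python
import os
from typing import Callable, Dict, Optional


def parse_tpu_env(text: str) -> Dict[str, str]:
    """Parse "KEY: 'value'" lines into a dict (line format is an assumption)."""
    result: Dict[str, str] = {}
    for line in text.splitlines():
        if ":" not in line:
            continue  # skip blank or malformed lines
        key, value = line.split(":", 1)  # split once; values may contain ':'
        result[key.strip()] = value.strip().strip("'")
    return result


def get_tpu_env_value(key: str,
                      fetch_tpu_env: Callable[[], str]) -> Optional[str]:
    """Check the process environment first; query metadata only on a miss."""
    value = os.environ.get(key)
    if value is not None:
        return value  # found locally, no network access needed
    # Fallback: in the traceback above, this is the step that raised
    # ConnectionError because the metadata server refused the connection.
    return parse_tpu_env(fetch_tpu_env()).get(key)
```

Passing a fetch_tpu_env that raises ConnectionError reproduces the traceback's failure path without a TPU VM, while a key that is already in os.environ never touches the fetcher.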

pip list

Package                 Version
----------------------- -----------------------
anaconda-anon-usage     0.4.4
archspec                0.2.3
boltons                 23.0.0
Brotli                  1.0.9
certifi                 2024.7.4
cffi                    1.16.0
charset-normalizer      3.3.2
conda                   24.7.1
conda-content-trust     0.2.0
conda-libmamba-solver   24.7.0
conda-package-handling  2.3.0
conda_package_streaming 0.10.0
cryptography            42.0.5
distro                  1.9.0
frozendict              2.4.2
idna                    3.7
jax                     0.4.34
jaxlib                  0.4.34
jsonpatch               1.33
jsonpointer             2.1
libmambapy              1.5.8
libtpu-nightly          0.1.dev20241002+nightly
menuinst                2.1.2
ml_dtypes               0.5.0
numpy                   2.1.2
opt_einsum              3.4.0
packaging               24.1
pip                     24.2
platformdirs            3.10.0
pluggy                  1.0.0
pycosat                 0.6.6
pycparser               2.21
PySocks                 1.7.1
requests                2.32.3
ruamel.yaml             0.17.21
scipy                   1.14.1
setuptools              72.1.0
tqdm                    4.66.4
truststore              0.8.0
urllib3                 2.2.2
wheel                   0.43.0
zstandard               0.22.0

Setup (bash)

# 1. Install Miniconda.
rm -rf ~/miniconda3

wget https://repo.anaconda.com/miniconda/Miniconda3-py311_24.7.1-0-Linux-x86_64.sh
bash Miniconda3-py311_24.7.1-0-Linux-x86_64.sh -b -u
rm Miniconda3-py311_24.7.1-0-Linux-x86_64.sh

~/miniconda3/bin/conda init bash
eval "$(~/miniconda3/bin/conda shell.bash hook)"

# 2. Install requirements.
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
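
After running the setup above, the versions that matter for this report can be read back without scanning the full pip list. A minimal standard-library sketch; the package selection (jax, jaxlib, libtpu-nightly, plus the requests/urllib3 stack that raised the error) is an assumption based on the traceback, not something the report prescribes.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional

# Distribution names as they appear in the pip list above (selection is illustrative).
PACKAGES = ("jax", "jaxlib", "libtpu-nightly", "requests", "urllib3")


def installed_versions(names: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def format_versions(versions: Dict[str, Optional[str]]) -> str:
    """Render a two-column listing similar to pip list."""
    return "\n".join(
        f"{name:<16} {ver or 'not installed'}" for name, ver in versions.items()
    )
```

Usage: print(format_versions(installed_versions())) inside the activated conda environment.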

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.34
jaxlib: 0.4.34
numpy:  2.1.2
python: 3.11.9 (main, Apr 19 2024, 16:48:06) [GCC 11.2.0]
jax.devices (128 total, 4 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0) ... TpuDevice(id=126, process_index=31, coords=(2,3,7), core_on_chip=0) TpuDevice(id=127, process_index=31, coords=(3,3,7), core_on_chip=0)]
process_count: 32
platform: uname_result(system='Linux', node='t1v-n-db3292ae-w-17', release='5.19.0-1022-gcp', version='#24~22.04.1-Ubuntu SMP Sun Apr 23 09:51:08 UTC 2023', machine='x86_64')

thiagolaitz commented 1 week ago

I'm facing the same problem.

rodrigo-f-nogueira commented 1 week ago

Hi JAX team,

We are also facing this issue with TPU v2, v3, and v4. However, v5p works fine.

gagika commented 1 week ago

The issue should be resolved by now. Could you please double-check on a newly created TPU VM and let us know if it's still happening?
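
When re-checking on a fresh VM, the specific failure from the traceback can be tested in isolation by requesting the same metadata URL directly. A minimal standard-library sketch; the timeout value, the proxy bypass, and the (ok, detail) return shape are illustrative choices, not anything JAX provides.

```python
import urllib.error
import urllib.request
from typing import Tuple

# The URL from the traceback above. The GCE metadata server rejects requests
# that lack the "Metadata-Flavor: Google" header.
TPU_ENV_URL = ("http://metadata.google.internal"
               "/computeMetadata/v1/instance/attributes/tpu-env")


def probe_metadata(url: str = TPU_ENV_URL, timeout: float = 2.0) -> Tuple[bool, str]:
    """Fetch url once and report (ok, detail) instead of raising."""
    request = urllib.request.Request(url, headers={"Metadata-Flavor": "Google"})
    # Bypass any configured HTTP proxy: the metadata server is reached directly.
    opener = urllib.request.build_opener(urllib.request.ProxyHandler({}))
    try:
        with opener.open(request, timeout=timeout) as response:
            body = response.read()
            return True, f"HTTP {response.status}, {len(body)} bytes"
    except urllib.error.HTTPError as err:
        # The server answered with a non-2xx status (e.g. 404 for a missing key).
        return False, f"HTTP {err.code}"
    except (urllib.error.URLError, OSError) as err:
        # No answer at all: DNS failure, connection refused, or timeout.
        return False, f"unreachable: {getattr(err, 'reason', err)}"
```

On a TPU VM, print(probe_metadata()) returning (False, 'unreachable: ...') would point to the same Connection refused seen above, without going through jax.distributed.initialize().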

demon2036 commented 1 week ago

@gagika Thank you for the update! I've tested on a newly created TPU VM, and the issue is now resolved. Everything is working as expected.