google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

jaxlib or libtpu not detected on TPU Pod #22070

Open ayaka14732 opened 1 week ago

ayaka14732 commented 1 week ago

Description

I have installed the TPU version of JAX (including jaxlib and libtpu) in a venv on all hosts of a TPU Pod. Then I ran the following command on all hosts:

. ~/venv/bin/activate; python -c 'import jax; print(jax.devices())'

I got this error:

A Google TPU may be present on this machine, but either a TPU-enabled jaxlib or libtpu is not installed. Falling back to cpu.

System info (python version, jaxlib version, accelerator, etc.)

>>> import jax; jax.print_environment_info()
A Google TPU may be present on this machine, but either a TPU-enabled jaxlib or libtpu is not installed. Falling back to cpu.
jax:    0.4.30
jaxlib: 0.4.30
numpy:  1.26.4
python: 3.12.4 (main, Jun  8 2024, 18:29:57) [GCC 11.4.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-d95163c5-w-0', release='5.19.0-1022-gcp', version='#24~22.04.1-Ubuntu SMP Sun Apr 23 09:51:08 UTC 2023', machine='x86_64')
$ ls /dev/accel*
/dev/accel0  /dev/accel1  /dev/accel2  /dev/accel3

yashk2810 commented 1 week ago

How did you install JAX?

ayaka14732 commented 1 week ago

@yashk2810 I ran these commands on all hosts:

python3.12 -m venv ~/venv
. ~/venv/bin/activate
pip install -U pip
pip install -U wheel
pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

yashk2810 commented 1 week ago

Can you show me your pip freeze?

ayaka14732 commented 1 week ago

@yashk2810

$ pip freeze
certifi==2024.6.2
charset-normalizer==3.3.2
idna==3.7
jax==0.4.30
jaxlib==0.4.30
libtpu-nightly==0.1.dev20240617
ml-dtypes==0.4.0
numpy==2.0.0
opt-einsum==3.3.0
requests==2.32.3
scipy==1.14.0
urllib3==2.2.2
wheel==0.43.0

skye commented 1 week ago

Can you try running JAX_DEBUG_LOG_MODULES=jax._src.xla_bridge python -c 'import jax; print(jax.devices())' and paste the output here?

ayaka14732 commented 1 week ago

@skye

DEBUG:2024-06-25 00:34:55,039:jax._src.xla_bridge:575: No jax_plugins namespace packages available
DEBUG:2024-06-25 00:34:55,049:jax._src.xla_bridge:969: Initializing backend 'cpu'
DEBUG:2024-06-25 00:34:55,109:jax._src.xla_bridge:981: Backend 'cpu' initialized
DEBUG:2024-06-25 00:34:55,109:jax._src.xla_bridge:969: Initializing backend 'cuda'
INFO:2024-06-25 00:34:55,109:jax._src.xla_bridge:889: Unable to initialize backend 'cuda': 
DEBUG:2024-06-25 00:34:55,109:jax._src.xla_bridge:969: Initializing backend 'rocm'
INFO:2024-06-25 00:34:55,109:jax._src.xla_bridge:889: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
DEBUG:2024-06-25 00:34:55,109:jax._src.xla_bridge:969: Initializing backend 'tpu'
INFO:2024-06-25 00:34:55,142:jax._src.xla_bridge:889: Unable to initialize backend 'tpu': ABORTED: The TPU is already in use by process with pid 176028. Not attempting to load libtpu.so in this process.
WARNING:2024-06-25 00:34:55,143:jax._src.xla_bridge:940: A Google TPU may be present on this machine, but either a TPU-enabled jaxlib or libtpu is not installed. Falling back to cpu.
[CpuDevice(id=0)]
DEBUG:2024-06-25 00:34:55,812:jax._src.xla_bridge:575: No jax_plugins namespace packages available
DEBUG:2024-06-25 00:34:55,813:jax._src.xla_bridge:575: No jax_plugins namespace packages available
DEBUG:2024-06-25 00:34:55,836:jax._src.xla_bridge:575: No jax_plugins namespace packages available
DEBUG:2024-06-25 00:34:55,866:jax._src.xla_bridge:969: Initializing backend 'cpu'
DEBUG:2024-06-25 00:34:55,875:jax._src.xla_bridge:969: Initializing backend 'cpu'
DEBUG:2024-06-25 00:34:55,919:jax._src.xla_bridge:969: Initializing backend 'cpu'
DEBUG:2024-06-25 00:34:55,928:jax._src.xla_bridge:981: Backend 'cpu' initialized
DEBUG:2024-06-25 00:34:55,928:jax._src.xla_bridge:969: Initializing backend 'cuda'
INFO:2024-06-25 00:34:55,928:jax._src.xla_bridge:889: Unable to initialize backend 'cuda': 
DEBUG:2024-06-25 00:34:55,928:jax._src.xla_bridge:969: Initializing backend 'rocm'
INFO:2024-06-25 00:34:55,928:jax._src.xla_bridge:889: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
DEBUG:2024-06-25 00:34:55,928:jax._src.xla_bridge:969: Initializing backend 'tpu'
DEBUG:2024-06-25 00:34:55,936:jax._src.xla_bridge:981: Backend 'cpu' initialized
DEBUG:2024-06-25 00:34:55,936:jax._src.xla_bridge:969: Initializing backend 'cuda'
INFO:2024-06-25 00:34:55,936:jax._src.xla_bridge:889: Unable to initialize backend 'cuda': 
DEBUG:2024-06-25 00:34:55,936:jax._src.xla_bridge:969: Initializing backend 'rocm'
INFO:2024-06-25 00:34:55,936:jax._src.xla_bridge:889: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
DEBUG:2024-06-25 00:34:55,936:jax._src.xla_bridge:969: Initializing backend 'tpu'
INFO:2024-06-25 00:34:55,960:jax._src.xla_bridge:889: Unable to initialize backend 'tpu': ABORTED: The TPU is already in use by process with pid 463521. Not attempting to load libtpu.so in this process.
WARNING:2024-06-25 00:34:55,961:jax._src.xla_bridge:940: A Google TPU may be present on this machine, but either a TPU-enabled jaxlib or libtpu is not installed. Falling back to cpu.
[CpuDevice(id=0)]
INFO:2024-06-25 00:34:55,970:jax._src.xla_bridge:889: Unable to initialize backend 'tpu': ABORTED: The TPU is already in use by process with pid 482056. Not attempting to load libtpu.so in this process.
WARNING:2024-06-25 00:34:55,971:jax._src.xla_bridge:940: A Google TPU may be present on this machine, but either a TPU-enabled jaxlib or libtpu is not installed. Falling back to cpu.
[CpuDevice(id=0)]
DEBUG:2024-06-25 00:34:55,982:jax._src.xla_bridge:981: Backend 'cpu' initialized
DEBUG:2024-06-25 00:34:55,982:jax._src.xla_bridge:969: Initializing backend 'cuda'
INFO:2024-06-25 00:34:55,982:jax._src.xla_bridge:889: Unable to initialize backend 'cuda': 
DEBUG:2024-06-25 00:34:55,982:jax._src.xla_bridge:969: Initializing backend 'rocm'
INFO:2024-06-25 00:34:55,982:jax._src.xla_bridge:889: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
DEBUG:2024-06-25 00:34:55,982:jax._src.xla_bridge:969: Initializing backend 'tpu'
INFO:2024-06-25 00:34:56,015:jax._src.xla_bridge:889: Unable to initialize backend 'tpu': ABORTED: The TPU is already in use by process with pid 149052. Not attempting to load libtpu.so in this process.
WARNING:2024-06-25 00:34:56,016:jax._src.xla_bridge:940: A Google TPU may be present on this machine, but either a TPU-enabled jaxlib or libtpu is not installed. Falling back to cpu.
[CpuDevice(id=0)]
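
The output above is interleaved across several hosts, which makes the real failure easy to miss. As a rough sketch (not part of JAX), the per-backend failure reasons can be pulled out of such a log with a regular expression. The helper names `backend_failures` and `tpu_holder_pids` are made up for this illustration, and the pattern assumes the message wording shown in this thread, which may differ in other JAX versions.

```python
import re

# Matches the message format seen in this thread, e.g.
# "... Unable to initialize backend 'tpu': ABORTED: The TPU is already in use ..."
# The exact wording is an assumption based on the log above.
FAILURE_RE = re.compile(
    r"Unable to initialize backend '(?P<backend>[^']+)': ?(?P<reason>.*)$"
)


def backend_failures(log_text):
    """Return a list of (backend, reason) pairs found in the log text."""
    failures = []
    for line in log_text.splitlines():
        match = FAILURE_RE.search(line)
        if match:
            failures.append((match.group("backend"), match.group("reason").strip()))
    return failures


def tpu_holder_pids(log_text):
    """Return the sorted pids reported as already holding the TPU."""
    pids = set()
    for backend, reason in backend_failures(log_text):
        pid_match = re.search(r"already in use by process with pid (\d+)", reason)
        if backend == "tpu" and pid_match:
            pids.add(int(pid_match.group(1)))
    return sorted(pids)


# Three lines copied from the log above, used as sample input.
sample_log = "\n".join([
    "INFO:2024-06-25 00:34:55,109:jax._src.xla_bridge:889: Unable to initialize backend 'cuda':",
    "INFO:2024-06-25 00:34:55,142:jax._src.xla_bridge:889: Unable to initialize backend 'tpu': ABORTED: The TPU is already in use by process with pid 176028. Not attempting to load libtpu.so in this process.",
    "WARNING:2024-06-25 00:34:55,143:jax._src.xla_bridge:940: A Google TPU may be present on this machine, but either a TPU-enabled jaxlib or libtpu is not installed. Falling back to cpu.",
])

for backend, reason in backend_failures(sample_log):
    print(f"{backend}: {reason or '(no reason given)'}")
print("pids holding the TPU:", tpu_holder_pids(sample_log))
```

Filtering this way surfaces the INFO-level "already in use" line, which is more informative here than the final WARNING.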

ayaka14732 commented 1 week ago

From the logs, I realised that the actual cause is that the TPU was already in use by another process. It works after that process is killed.
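
In this thread the debug log already names the process holding the TPU on each host. Where only the fallback warning is visible, one possible way to look for the holder is to scan `/proc` for processes that have a `/dev/accel*` device file open. This is a minimal, Linux-only sketch: `find_device_holders` is a hypothetical helper, and it assumes the TPU runtime keeps those device files open while it owns the chip, which this thread does not confirm. Inspecting other users' processes generally needs root.

```python
import glob
import os


def find_device_holders(device_glob="/dev/accel*", proc_root="/proc"):
    """Return {pid: [device paths]} for processes holding matching files open.

    Hypothetical helper for illustration. Processes whose fd table cannot be
    read (exited, or insufficient permissions) are silently skipped.
    """
    devices = set(glob.glob(device_glob))
    holders = {}
    if not devices or not os.path.isdir(proc_root):
        return holders
    for entry in os.listdir(proc_root):
        if not entry.isdigit():
            continue  # not a process directory
        fd_dir = os.path.join(proc_root, entry, "fd")
        try:
            fds = os.listdir(fd_dir)
        except OSError:
            continue
        for fd in fds:
            try:
                target = os.readlink(os.path.join(fd_dir, fd))
            except OSError:
                continue  # fd was closed while we were scanning
            if target in devices:
                holders.setdefault(int(entry), set()).add(target)
    return {pid: sorted(paths) for pid, paths in holders.items()}


if __name__ == "__main__":
    for pid, paths in sorted(find_device_holders().items()):
        print(pid, ", ".join(paths))
```

On a host where nothing matches the glob, the function returns an empty dictionary.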

skye commented 1 week ago

Ah. This is supposed to be raised as an exception instead of falling back to CPU. That functionality must have regressed. Now to figure out why...
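
Until the exception behaviour is restored, a script can guard against the silent CPU fallback itself. The sketch below is not JAX's own mechanism: `require_platform` is a hypothetical helper that relies only on each device object exposing a `platform` attribute, as the objects returned by `jax.devices()` do. A stand-in object is used so the example runs without a TPU.

```python
from types import SimpleNamespace


def require_platform(devices, expected="tpu"):
    """Raise if the device list is empty or contains any other platform.

    Hypothetical guard for illustration; `devices` is any iterable of objects
    with a `platform` attribute, such as the result of `jax.devices()`.
    """
    platforms = sorted({device.platform for device in devices})
    if platforms != [expected]:
        raise RuntimeError(
            f"expected only {expected!r} devices, found platforms {platforms}"
        )
    return devices


# Stand-in for the [CpuDevice(id=0)] fallback reported in this issue.
fallback = [SimpleNamespace(platform="cpu", id=0)]
try:
    require_platform(fallback)
except RuntimeError as error:
    print("refusing to continue:", error)
```

In a real job this would be called as `require_platform(jax.devices())` right after importing JAX. As far as I know, setting the `JAX_PLATFORMS=tpu` environment variable is another way to ask JAX to fail rather than fall back, though that should be checked against the JAX version in use.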