google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Jax not recognizing GPU. #21240

Open charlie-guan opened 4 months ago

charlie-guan commented 4 months ago

Description

I am trying to reproduce this work from Google DeepMind by running JAX on NVIDIA GPUs (driver 550.67) with CUDA 12.4, but it returns

"No GPU/TPU found, falling back to CPU."

I tried bumping the jax and jaxlib versions to 0.4.28 (the latest version) from 0.4.16 (the version listed in requirements.txt) and also upgraded flax to 0.8.3 from 0.7.4. These changes eliminate the warning and JAX seems to recognize the GPU devices, but the computation is still very slow. How do I fix this?
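For reference, a minimal sketch of installing a CUDA-enabled build and confirming which backend JAX picks up; the jax[cuda12] pip extra and the version pin here are assumptions based on the 0.4.28 / CUDA 12.4 combination described above:

# Sketch only: install a CUDA 12 build of jax/jaxlib, e.g.
#   pip install -U "jax[cuda12]==0.4.28"
import jax

# If the CUDA-enabled jaxlib (or its CUDA plugin) is missing, this prints 'cpu'
# and JAX emits "No GPU/TPU found, falling back to CPU."
print(jax.default_backend())  # expect 'gpu'
print(jax.devices())          # expect a list of CUDA devices, not a single CPU device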

System info (python version, jaxlib version, accelerator, etc.)

$ python3 -c "import jax; jax.print_environment_info()"
jax:    0.4.28
jaxlib: 0.4.28
numpy:  1.26.0
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (10 total, 10 local): [cuda(id=0) cuda(id=1) ... cuda(id=8) cuda(id=9)]
process_count: 1
platform: uname_result(system='Linux', node='nebula', release='5.15.0-107-generic', version='#117-Ubuntu SMP Fri Apr 26 12:26:49 UTC 2024', machine='x86_64')
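A quick sanity check that individual arrays and computations actually land on a CUDA device (rather than the backend merely being listed) could look like this sketch:

import jax
import jax.numpy as jnp

x = jnp.ones((4096, 4096))    # placed on the default device
print(jax.default_backend())  # expect 'gpu'
print(x.devices())            # expect a CUDA device, not a CPU device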

$ nvidia-smi
Wed May 15 09:53:43 2024
NVIDIA-SMI 550.67    Driver Version: 550.67    CUDA Version: 12.4
(all GPUs: Fan 30%, Perf P2, Persistence-M Off, Disp.A Off, ECC Off, Compute M. Default, MIG N/A)

GPU  Name              Bus-Id            Temp  Pwr:Usage/Cap  Memory-Usage           GPU-Util
0    NVIDIA RTX A6000  00000000:1A:00.0  37C    73W / 300W    37062MiB / 49140MiB     17%
1    NVIDIA RTX A6000  00000000:1B:00.0  50C   135W / 300W    39244MiB / 49140MiB    100%
2    NVIDIA RTX A6000  00000000:1C:00.0  39C    80W / 300W    16274MiB / 49140MiB      0%
3    NVIDIA RTX A6000  00000000:1D:00.0  56C   146W / 300W    48601MiB / 49140MiB    100%
4    NVIDIA RTX A6000  00000000:1E:00.0  57C   149W / 300W    48037MiB / 49140MiB    100%
5    NVIDIA RTX A6000  00000000:3D:00.0  49C   120W / 300W    46176MiB / 49140MiB    100%
6    NVIDIA RTX A6000  00000000:3E:00.0  52C   158W / 300W    47214MiB / 49140MiB    100%
7    NVIDIA RTX A6000  00000000:3F:00.0  57C   117W / 300W    47026MiB / 49140MiB    100%
8    NVIDIA RTX A6000  00000000:40:00.0  56C   159W / 300W    32440MiB / 49140MiB    100%
9    NVIDIA RTX A6000  00000000:41:00.0  53C   123W / 300W    47000MiB / 49140MiB    100%

Processes (GI ID / CI ID: N/A for all):
GPU  PID     Type  Process name                             GPU Memory Usage
0    185733  C     python3                                  36788MiB
0    191989  C     python3                                  262MiB
1    190770  C     ...liu/anaconda3/envs/gemma/bin/python   38972MiB
1    191989  C     python3                                  262MiB
2    176896  C     python3                                  16000MiB
2    191989  C     python3                                  262MiB
3    176896  C     python3                                  262MiB
3    190769  C     ...liu/anaconda3/envs/gemma/bin/python   48062MiB
3    191989  C     python3                                  262MiB
4    176896  C     python3                                  262MiB
4    190772  C     ...liu/anaconda3/envs/gemma/bin/python   47498MiB
4    191989  C     python3                                  262MiB
5    190771  C     ...liu/anaconda3/envs/gemma/bin/python   45904MiB
5    191989  C     python3                                  262MiB
6    190773  C     ...liu/anaconda3/envs/gemma/bin/python   46942MiB
6    191989  C     python3                                  262MiB
7    190774  C     ...liu/anaconda3/envs/gemma/bin/python   46754MiB
7    191989  C     python3                                  262MiB
8    190767  C     ...liu/anaconda3/envs/gemma/bin/python   32360MiB
8    191989  C     python3                                  262MiB
9    190768  C     ...liu/anaconda3/envs/gemma/bin/python   46728MiB
9    191989  C     python3                                  262MiB
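Note that nvidia-smi above shows most of the ten GPUs already near-full and at 100% utilization from other processes, which by itself can explain poor throughput. A hedged sketch of pinning JAX to an idle GPU and relaxing its default memory preallocation follows; the specific GPU index is just an example taken from the table above:

import os

# Must be set before JAX initializes its CUDA backend (easiest: before importing jax).
os.environ["CUDA_VISIBLE_DEVICES"] = "2"   # GPU 2 shows 0% utilization above

# JAX preallocates ~75% of GPU memory by default; either disable preallocation
# or cap the fraction so it can coexist with the other processes.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# or: os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.30"

import jax
print(jax.devices())   # should now list only the pinned CUDA device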

hawkinsp commented 4 months ago

Hmm. I'm not sure this is an actionable report. You upgraded and the original problem was fixed, it seems?

You say the model is running slowly. Can you say more? Knowing nothing about that particular model: are you seeing different performance characteristics than the original model authors reported? How so?
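One common source of apparent slowness, independent of the specifics of this model, is timing that includes XLA compilation or ignores JAX's asynchronous dispatch. A minimal sketch of separating compile time from steady-state GPU time, where the step function is just a hypothetical stand-in for the real workload:

import time
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    # Stand-in for one training/inference step
    return (x @ x).sum()

x = jnp.ones((4096, 4096))

t0 = time.perf_counter()
step(x).block_until_ready()   # first call includes XLA compilation
print("first call (with compile):", time.perf_counter() - t0)

t0 = time.perf_counter()
step(x).block_until_ready()   # subsequent calls measure the compiled GPU execution
print("steady state:", time.perf_counter() - t0)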

charlie-guan commented 4 months ago

I resolved this issue.

rajasekharporeddy commented 1 week ago

Hi @charlie-guan

Please feel free to close the issue if it has been resolved.

Thank you.