google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

jax.jit(function).lower(x, y).cost_analysis() gives NONE while running with GPU as the device #22713

Open Aki1991 opened 1 month ago

Aki1991 commented 1 month ago

Description

I am training an Owl_Vit model. When training on the GPU, at one stage jax.jit(function).lower(x, y).cost_analysis() returns None.

When I use the CPU version of JAX, it works, and I get this output: {'bytes accessed0{}': 4058464512.0, 'utilization0{}': 2059.1572265625, 'bytes accessedout{}': 4285124864.0, 'bytes accessed2{}': 13349412.0, 'utilization2{}': 21.0, 'utilization1{}': 1139.0, 'bytes accessed1{}': 2781101056.0, 'flops': 237087424512.0, 'transcendentals': 111938816.0, 'utilization4{}': 1.0, 'bytes accessed': 11138055168.0, 'bytes accessed3{}': 4.0, 'utilization3{}': 2.0}

So the code itself seems fine; I think the problem is the JAX version. Versions 0.4.23 and 0.4.30 give the same result. I tried 0.4.23 specifically because newer versions raise the error jax.random has no attribute PRNGKeyArray, which does not occur with versions up to 0.4.23. But with both versions, it does not work on GPU.

I installed JAX in a conda environment with

pip install -U "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

and

pip install -U "jax[cuda12]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

I tried CUDA versions from 12.1 to 12.5 and got the same result each time. The base OS has CUDA 12.4.

Can anyone tell me what the issue might be? Thank you.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.30
jaxlib: 0.4.30
numpy:  1.26.4
python: 3.10.0 | packaged by conda-forge | (default, Nov 20 2021, 02:24:10) [GCC 9.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='AHMCPU2839', release='5.15.0-113-generic', version='#123~20.04.1-Ubuntu SMP Wed Jun 12 17:33:13 UTC 2024', machine='x86_64')

$ nvidia-smi
Mon Jul 29 17:22:17 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.78                 Driver Version: 550.78         CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3090        On  |   00000000:21:00.0  On |                  N/A |
| 30%   36C    P2             27W /  350W |     877MiB /  24576MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      1661      G   /usr/lib/xorg/Xorg                            198MiB |
|    0   N/A  N/A      2996      G   /usr/lib/xorg/Xorg                            281MiB |
|    0   N/A  N/A      3204      G   /usr/bin/gnome-shell                           39MiB |
|    0   N/A  N/A    408669      G   ...ures=SpareRendererForSitePerProcess         20MiB |
|    0   N/A  N/A    454619      G   ...seed-version=20240716-180143.517000         36MiB |
|    0   N/A  N/A   1467281      C   python                                        256MiB |
+-----------------------------------------------------------------------------------------+

jakevdp commented 1 month ago

v0.4.23 is fairly old; I'd suggest trying with a more recent JAX version, particularly if you are using more recent CUDA versions.

I tried this specific version because it gives error jax.random has no attribute PRNGKeyArray

jax.random.PRNGKeyArray was deprecated in JAX v0.4.16 and removed in JAX v0.4.24 (see the Change log). You can replace it with jax.Array to address this error, and then hopefully use a more recent JAX version.
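A minimal sketch of that migration, assuming the key only appears in type annotations; init_params is a hypothetical helper, not part of Owl_Vit or JAX:

```python
import jax

# Before (AttributeError on JAX >= 0.4.24, where PRNGKeyArray was removed):
# def init_params(key: jax.random.PRNGKeyArray, shape): ...

def init_params(key: jax.Array, shape: tuple) -> jax.Array:
    """Hypothetical helper: sample normal initial weights from a PRNG key."""
    return jax.random.normal(key, shape)

# Both legacy uint32 keys (PRNGKey) and typed keys (jax.random.key) are jax.Array instances.
legacy_key = jax.random.PRNGKey(0)
typed_key = jax.random.key(0)

w = init_params(legacy_key, (3, 4))
print(w.shape)  # (3, 4)
```

Since both key styles satisfy the jax.Array annotation, the annotation change alone should not affect runtime behavior.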

Aki1991 commented 1 month ago

I replaced jax.random.PRNGKeyArray with jax.Array, which fixed the error jax.random has no attribute PRNGKeyArray.

I'd suggest trying with a more recent JAX version

I tried v0.4.31, but jax.jit(function).lower(x, y).cost_analysis() still returns None. UPDATE: I tried every version from 0.4.23 to 0.4.30, with the same result.

Could there be a problem with my CUDA configuration? This is the CUDA setup in my conda environment:

cuda                      12.4.0                        0    nvidia
cuda-cccl                 12.5.39                       0    nvidia
cuda-cccl_linux-64        12.5.39                       0    nvidia
cuda-command-line-tools   12.5.1                        0    nvidia
cuda-compiler             12.5.1                        0    nvidia
cuda-cudart               12.5.82                       0    nvidia
cuda-cudart-dev           12.5.82                       0    nvidia
cuda-cudart-dev_linux-64  12.5.82                       0    nvidia
cuda-cudart-static        12.5.82                       0    nvidia
cuda-cudart-static_linux-64 12.5.82                       0    nvidia
cuda-cudart_linux-64      12.5.82                       0    nvidia
cuda-cuobjdump            12.5.39                       0    nvidia
cuda-cupti                12.5.82                       0    nvidia
cuda-cupti-dev            12.5.82                       0    nvidia
cuda-cuxxfilt             12.5.82                       0    nvidia
cuda-demo-suite           12.4.127                      0    nvidia
cuda-driver-dev           12.5.82                       0    nvidia
cuda-driver-dev_linux-64  12.5.82                       0    nvidia
cuda-gdb                  12.5.82                       0    nvidia
cuda-libraries            12.5.1                        0    nvidia
cuda-libraries-dev        12.5.1                        0    nvidia
cuda-nsight               12.5.82                       0    nvidia
cuda-nvcc                 12.4.131                      0    nvidia
cuda-nvdisasm             12.5.39                       0    nvidia
cuda-nvml-dev             12.5.82                       0    nvidia
cuda-nvprof               12.5.82                       0    nvidia
cuda-nvprune              12.5.82                       0    nvidia
cuda-nvrtc                12.5.82                       0    nvidia
cuda-nvrtc-dev            12.5.82                       0    nvidia
cuda-nvtx                 12.5.82                       0    nvidia
cuda-nvvp                 12.5.82                       0    nvidia
cuda-opencl               12.5.39                       0    nvidia
cuda-opencl-dev           12.5.39                       0    nvidia
cuda-profiler-api         12.5.39                       0    nvidia
cuda-runtime              12.5.1                        0    nvidia
cuda-sanitizer-api        12.5.81                       0    nvidia
cuda-toolkit              12.5.1                        0    nvidia
cuda-tools                12.5.1                        0    nvidia
cuda-version              12.5                          3    nvidia
cuda-visual-tools         12.5.1                        0    nvidia
jax-cuda12-pjrt           0.4.31                   pypi_0    pypi
jax-cuda12-plugin         0.4.31                   pypi_0    pypi
nvidia-cuda-cupti-cu12    12.3.101                 pypi_0    pypi
nvidia-cuda-nvcc-cu12     12.3.107                 pypi_0    pypi
nvidia-cuda-nvrtc-cu12    12.3.107                 pypi_0    pypi
nvidia-cuda-runtime-cu12  12.3.101                 pypi_0    pypi
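The system info earlier in the issue looks like the output of jax.print_environment_info(). As a quick sanity check on the CUDA setup, a sketch like the following shows which backend and devices JAX actually picked up; it does not inspect the conda packages themselves:

```python
import jax

# Platform JAX selected by default, e.g. "gpu" when the CUDA plugin loaded, otherwise "cpu".
print("default backend:", jax.default_backend())
# Devices visible to JAX on that backend.
print("devices:", jax.devices())
# jax/jaxlib/numpy/python versions, device list, and nvidia-smi output when available.
jax.print_environment_info()
```

If the backend reports "gpu" and the device list shows the RTX 3090, JAX is using the CUDA installation, which points away from a broken CUDA setup as the cause.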

vfdev-5 commented 1 month ago

@Aki1991 you can try running jax.jit(function).lower(x, y).compile().cost_analysis() to get the cost analysis on GPU. For example:

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.sin(jnp.cos(x))

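c = f.lower(3.)  # trace and lower for a scalar float argument; nothing is compiled yet
print("cost_analysis:", c.compile().cost_analysis())  # compile first, then query the executable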
# cost_analysis: [{'transcendentals': 2.0, 'utilization0{}': 1.0, 'bytes accessed0{}': 4.0, 'bytes accessedout{}': 4.0, 'bytes accessed': 8.0}]
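To compare the two calls on a given machine, a side-by-side sketch like this can help. Here f, x and y are illustrative stand-ins for the model function and inputs in the issue. According to this thread, the lowered value comes back as None on GPU with the JAX versions tried (0.4.23 to 0.4.31), while the compiled analysis is populated:

```python
import jax
import jax.numpy as jnp

def f(x, y):
    # Illustrative stand-in for the model function: a matmul followed by tanh.
    return jnp.tanh(x @ y)

x = jnp.ones((128, 256), dtype=jnp.float32)
y = jnp.ones((256, 64), dtype=jnp.float32)

lowered = jax.jit(f).lower(x, y)  # traced and lowered, not yet compiled
print("backend:", jax.default_backend())
# Estimate from the lowered (pre-compilation) module; may be None on some backends.
print("lowered :", lowered.cost_analysis())
# Analysis reported by the executable compiled for the current backend.
print("compiled:", lowered.compile().cost_analysis())
```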

jmhuer commented 1 week ago

I'm also using jax 0.4.30 and jaxlib 0.4.30.

As suggested by @vfdev-5, I changed

analysis = jax.jit(flax_model_apply_fn).lower(*dummy_input).cost_analysis()

to

analysis = jax.jit(flax_model_apply_fn).lower(*dummy_input).compile().cost_analysis()[0]

and that fixed the problem for me.
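The [0] above relies on compile().cost_analysis() returning a list with one dict per executable, as in the JAX 0.4.30 output shown earlier. If the same code has to run across JAX versions, a small wrapper can accept either a list or a single dict. This is a hedged sketch: compiled_cost_analysis is a hypothetical name, and the assumption that some versions return a plain dict is worth checking against the changelog for your version.

```python
import jax
import jax.numpy as jnp

def compiled_cost_analysis(fn, *args):
    """Hypothetical helper: jit, lower, and compile fn for args, then return one cost dict."""
    analysis = jax.jit(fn).lower(*args).compile().cost_analysis()
    # Assumption: depending on the JAX version this is a list of dicts (one per
    # executable, as in 0.4.30) or a single dict; normalize to a single dict.
    if isinstance(analysis, (list, tuple)):
        analysis = analysis[0] if analysis else None
    return analysis

def f(x):
    return jnp.sin(jnp.cos(x))

cost = compiled_cost_analysis(f, 3.0)
print("flops:", cost.get("flops"), "transcendentals:", cost.get("transcendentals"))
```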