Aki1991 opened this issue 1 month ago (status: Open).
v0.4.23 is fairly old; I'd suggest trying with a more recent JAX version, particularly if you are using more recent CUDA versions.
I tried this specific version because newer versions give the error: jax.random has no attribute PRNGKeyArray
jax.random.PRNGKeyArray was deprecated in JAX v0.4.16 and removed in JAX v0.4.24 (see the changelog). You can replace it with jax.Array to address this error, and then hopefully use a more recent JAX version.
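A minimal sketch of that replacement, assuming the key only appears in a type annotation. The function init_params is hypothetical (it is not from the Owl_Vit code); only the annotation change reflects the advice above.

```python
import jax

# Before (fails on JAX >= 0.4.24, where jax.random.PRNGKeyArray was removed):
#     def init_params(key: jax.random.PRNGKeyArray, shape=(3,)): ...
# After: annotate the key as a plain jax.Array instead.
def init_params(key: jax.Array, shape=(3,)) -> jax.Array:
    # Hypothetical initializer: draw standard-normal parameters from the key.
    return jax.random.normal(key, shape)


key = jax.random.PRNGKey(0)
params = init_params(key)
print(params.shape)  # (3,)
```

Because the annotation is not enforced at runtime, this change alone does not alter behaviour; it only removes the reference to the deleted attribute.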
I replaced "jax.random.PRNGKeyArray" with "jax.Array", which fixed the "jax.random has no attribute PRNGKeyArray" error.
I'd suggest trying with a more recent JAX version
I tried v0.4.31, but jax.jit(function).lower(x, y).cost_analysis() still returns None.
UPDATE: I tried every version from 0.4.23 to 0.4.30, with the same result.
Could the CUDA configuration be the problem? This is the CUDA configuration in my conda environment:
cuda 12.4.0 0 nvidia
cuda-cccl 12.5.39 0 nvidia
cuda-cccl_linux-64 12.5.39 0 nvidia
cuda-command-line-tools 12.5.1 0 nvidia
cuda-compiler 12.5.1 0 nvidia
cuda-cudart 12.5.82 0 nvidia
cuda-cudart-dev 12.5.82 0 nvidia
cuda-cudart-dev_linux-64 12.5.82 0 nvidia
cuda-cudart-static 12.5.82 0 nvidia
cuda-cudart-static_linux-64 12.5.82 0 nvidia
cuda-cudart_linux-64 12.5.82 0 nvidia
cuda-cuobjdump 12.5.39 0 nvidia
cuda-cupti 12.5.82 0 nvidia
cuda-cupti-dev 12.5.82 0 nvidia
cuda-cuxxfilt 12.5.82 0 nvidia
cuda-demo-suite 12.4.127 0 nvidia
cuda-driver-dev 12.5.82 0 nvidia
cuda-driver-dev_linux-64 12.5.82 0 nvidia
cuda-gdb 12.5.82 0 nvidia
cuda-libraries 12.5.1 0 nvidia
cuda-libraries-dev 12.5.1 0 nvidia
cuda-nsight 12.5.82 0 nvidia
cuda-nvcc 12.4.131 0 nvidia
cuda-nvdisasm 12.5.39 0 nvidia
cuda-nvml-dev 12.5.82 0 nvidia
cuda-nvprof 12.5.82 0 nvidia
cuda-nvprune 12.5.82 0 nvidia
cuda-nvrtc 12.5.82 0 nvidia
cuda-nvrtc-dev 12.5.82 0 nvidia
cuda-nvtx 12.5.82 0 nvidia
cuda-nvvp 12.5.82 0 nvidia
cuda-opencl 12.5.39 0 nvidia
cuda-opencl-dev 12.5.39 0 nvidia
cuda-profiler-api 12.5.39 0 nvidia
cuda-runtime 12.5.1 0 nvidia
cuda-sanitizer-api 12.5.81 0 nvidia
cuda-toolkit 12.5.1 0 nvidia
cuda-tools 12.5.1 0 nvidia
cuda-version 12.5 3 nvidia
cuda-visual-tools 12.5.1 0 nvidia
jax-cuda12-pjrt 0.4.31 pypi_0 pypi
jax-cuda12-plugin 0.4.31 pypi_0 pypi
nvidia-cuda-cupti-cu12 12.3.101 pypi_0 pypi
nvidia-cuda-nvcc-cu12 12.3.107 pypi_0 pypi
nvidia-cuda-nvrtc-cu12 12.3.107 pypi_0 pypi
nvidia-cuda-runtime-cu12 12.3.101 pypi_0 pypi
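Before digging into individual conda packages, it may help to confirm which backend JAX actually loaded. This is a minimal diagnostic sketch using standard JAX calls; it only reports what JAX sees and does not by itself explain the cost_analysis() behaviour.

```python
import jax

# Prints jax/jaxlib/numpy/python versions and, if available, nvidia-smi output.
jax.print_environment_info()

# Which backend did JAX pick up? A GPU backend means the CUDA plugin loaded;
# "cpu" means JAX fell back to the CPU (typically with a warning).
backend = jax.default_backend()
print("default backend:", backend)
print("devices:", jax.devices())
```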
@Aki1991 you could try running jax.jit(function).lower(x, y).compile().cost_analysis() to get the cost analysis on GPU.
For example:
import jax
import jax.numpy as jnp
@jax.jit
def f(x):
    return jnp.sin(jnp.cos(x))
c = f.lower(3.)
print("cost_analysis:", c.compile().cost_analysis())
# cost_analysis: [{'transcendentals': 2.0, 'utilization0{}': 1.0, 'bytes accessed0{}': 4.0, 'bytes accessedout{}': 4.0, 'bytes accessed': 8.0}]
I'm also using jax 0.4.30 and jaxlib 0.4.30.
As suggested by @vfdev-5, I changed
analysis = jax.jit(flax_model_apply_fn).lower(*dummy_input).cost_analysis()
to
analysis = jax.jit(flax_model_apply_fn).lower(*dummy_input).compile().cost_analysis()[0]
and that fixed the problem for me.
Description
I am training an Owl_Vit model. When training on GPU, at one stage
jax.jit(function).lower(x, y).cost_analysis()
returns None. When I use the CPU version of JAX, however, it works and I get this output:
{'bytes accessed0{}': 4058464512.0, 'utilization0{}': 2059.1572265625, 'bytes accessedout{}': 4285124864.0, 'bytes accessed2{}': 13349412.0, 'utilization2{}': 21.0, 'utilization1{}': 1139.0, 'bytes accessed1{}': 2781101056.0, 'flops': 237087424512.0, 'transcendentals': 111938816.0, 'utilization4{}': 1.0, 'bytes accessed': 11138055168.0, 'bytes accessed3{}': 4.0, 'utilization3{}': 2.0}
So there is no problem with the code itself; I think the JAX version is causing the error. Versions 0.4.23 and 0.4.30 both give the same result. I tried 0.4.23 specifically because newer versions raise the error
jax.random has no attribute PRNGKeyArray
. With versions up to 0.4.23 this error does not occur, but on GPU neither version works. I installed JAX with
pip install -U "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
and
pip install -U "jax[cuda12]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
. I am using a conda environment. I tried CUDA 12.1 through CUDA 12.5, but I get the same error. My base OS has CUDA 12.4. Can anyone tell me what the issue might be? Thank you.
System info (python version, jaxlib version, accelerator, etc.)