google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

ptxas version mismatch #488

Closed: misterguick closed this issue 1 day ago

misterguick commented 1 month ago

Hi all,

I work on a cluster with CUDA 12.4, and I'm trying to use Brax through the TorchRL wrapper. Among several issues, I keep running into a strange compatibility warning. Here is the setup:

import brax.envs
base_env = brax.envs.get_environment("halfcheetah")

which gives

W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.40). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.

This is not a blocking issue, but it is still very annoying.

If I run nvidia-smi, I get:

NVIDIA-SMI 550.78 Driver Version: 550.78 CUDA Version: 12.4

If I run I get

ptxas: NVIDIA (R) Ptx optimizing assembler
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Mar_28_02:14:54_PDT_2024
Cuda compilation tools, release 12.4, V12.4.131
Build cuda_12.4.r12.4/compiler.34097967_0

There seems to be a mismatch between the ptxas version named in the warning (12.5.40) and the one reported by this last command (12.4).
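To make the comparison concrete, here is a small sketch that parses the three version strings quoted above and applies the rule the warning describes (the driver's CUDA version is older than the ptxas version). It assumes only the major.minor part matters; XLA's exact comparison logic is not shown in this thread, and the helper names are illustrative.

```python
import re

# Version strings copied verbatim from the output quoted in this issue.
WARNING = ("The NVIDIA driver's CUDA version is 12.4 which is older than "
           "the ptxas CUDA version (12.5.40).")
NVIDIA_SMI = "NVIDIA-SMI 550.78 Driver Version: 550.78 CUDA Version: 12.4"
PTXAS_ON_PATH = "Cuda compilation tools, release 12.4, V12.4.131"


def major_minor(version: str) -> tuple[int, int]:
    """'12.5.40' -> (12, 5); patch levels are ignored (an assumption)."""
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


def extract(pattern: str, text: str) -> tuple[int, int]:
    """Pull the first captured version out of text, reduced to major.minor."""
    match = re.search(pattern, text)
    if match is None:
        raise ValueError(f"pattern {pattern!r} not found in {text!r}")
    return major_minor(match.group(1))


def driver_too_old(driver: tuple[int, int], ptxas: tuple[int, int]) -> bool:
    """True when the ptxas toolkit is newer than the driver's CUDA version."""
    return driver < ptxas


driver = extract(r"CUDA Version:\s*([\d.]+)", NVIDIA_SMI)
ptxas_used_by_xla = extract(r"ptxas CUDA version \(([\d.]+)\)", WARNING)
ptxas_on_path = extract(r"release ([\d.]+)", PTXAS_ON_PATH)

print("driver:", driver)
print("ptxas named in warning:", ptxas_used_by_xla,
      "-> driver too old:", driver_too_old(driver, ptxas_used_by_xla))
print("ptxas on PATH:", ptxas_on_path,
      "-> driver too old:", driver_too_old(driver, ptxas_on_path))
```

This reports True for the 12.5.40 ptxas and False for the 12.4 ptxas on PATH, which is consistent with the warning being triggered by a different ptxas binary than the one the shell finds.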

Here are the versions of the related packages:

brax                 0.10.4                   pypi_0    pypi
jax                  0.4.28                   pypi_0    pypi
jax-cuda12-pjrt      0.4.28                   pypi_0    pypi
jax-cuda12-plugin    0.4.28                   pypi_0    pypi
jaxlib               0.4.28+cuda12.cudnn89    pypi_0    pypi
jaxopt               0.8.3                    pypi_0    pypi
jaxtyping            0.2.29                   pypi_0    pypi

I tried reinstalling everything (JAX, Brax, CUDA) from both pip and conda, but the warning never went away.

Thank you very much in advance!

erikfrey commented 1 day ago

Hello! You may have better luck asking this question in the JAX repo, but my general understanding is that the jax wheel ships with its own CUDA binaries and prefers to use those. It's probably using a newer ptxas than the one on your PATH, which reported version 12.4.
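One way to check this explanation is to look for a ptxas shipped inside the Python environment rather than on PATH. The sketch below assumes that CUDA-enabled JAX pip installs pull in NVIDIA's `nvidia-cuda-nvcc-cu12` wheel, which places ptxas under `<site-packages>/nvidia/cuda_nvcc/bin/`. That directory layout is an assumption, and `bundled_ptxas_candidates` is a hypothetical helper, not part of JAX.

```python
import shutil
import sys
from pathlib import Path

# Assumed location of ptxas inside the nvidia-cuda-nvcc-cu12 pip wheel,
# relative to a site-packages directory.
BUNDLED_RELATIVE = Path("nvidia") / "cuda_nvcc" / "bin" / "ptxas"


def bundled_ptxas_candidates(search_roots=None) -> list[Path]:
    """Return ptxas binaries found under the given roots (default: sys.path)."""
    roots = sys.path if search_roots is None else search_roots
    found: list[Path] = []
    for root in roots:
        if not root:  # sys.path may contain '' for the current directory
            continue
        candidate = Path(root) / BUNDLED_RELATIVE
        if candidate.is_file() and candidate not in found:
            found.append(candidate)
    return found


if __name__ == "__main__":
    # The copy the shell would run versus any copies bundled with pip packages.
    print("ptxas on PATH:    ", shutil.which("ptxas"))
    for path in bundled_ptxas_candidates():
        print("pip-bundled ptxas:", path)
```

If a bundled binary shows up, running it with `--version` should reveal whether it is the 12.5 copy the warning refers to.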

If getting rid of this warning is important, you can either update your driver, or there is probably a way to get JAX to use your system copy of CUDA/cuDNN.
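A hedged sketch of the second option: XLA accepts a `--xla_gpu_cuda_data_dir` flag, passed through the `XLA_FLAGS` environment variable, that points it at a CUDA installation directory. Whether this takes precedence over a pip-installed ptxas may depend on the JAX/XLA version, so treat it as something to try rather than a guaranteed fix. The helper names and the `/usr/local/cuda-12.4` path are hypothetical.

```python
import os

FLAG = "--xla_gpu_cuda_data_dir="


def with_cuda_data_dir(existing_flags: str, cuda_dir: str) -> str:
    """Return an XLA_FLAGS string that points XLA at cuda_dir.

    Any earlier --xla_gpu_cuda_data_dir setting is dropped so the new one
    wins; all other flags are kept in their original order.
    """
    kept = [flag for flag in existing_flags.split() if not flag.startswith(FLAG)]
    kept.append(FLAG + cuda_dir)
    return " ".join(kept)


def configure_system_cuda(cuda_dir: str) -> None:
    """Update XLA_FLAGS in this process; call before importing jax or brax."""
    current = os.environ.get("XLA_FLAGS", "")
    os.environ["XLA_FLAGS"] = with_cuda_data_dir(current, cuda_dir)


# Usage (hypothetical path; run before JAX initializes its GPU backend):
#   configure_system_cuda("/usr/local/cuda-12.4")
#   import brax.envs
#   base_env = brax.envs.get_environment("halfcheetah")
```

Setting the variable in Python before the first `import jax` keeps the change local to one script; exporting `XLA_FLAGS` in the shell would apply it to every JAX process instead.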