Closed by misterguick.
Hello, you may have better luck asking this question in the JAX repo. My general understanding is that the jax wheel ships with its own CUDA binaries and prefers to use those, so it is probably using a newer ptxas than the one on your PATH, which reported version 12.4.
If getting rid of this warning is important, you can either update your NVIDIA driver or, probably, find a way to make JAX use your system copy of CUDA/cuDNN.
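To check this hypothesis, one could compare the CUDA version reported by the driver with the release of each ptxas on the machine. The sketch below is a minimal, hedged example, not a JAX tool. It assumes that `nvidia-smi` prints a `CUDA Version: X.Y` field, that `ptxas --version` prints a `release X.Y` line, and that the pip package `nvidia-cuda-nvcc-cu12` (if installed) places its ptxas under `nvidia/cuda_nvcc/bin/`. Whether JAX actually picks up that particular binary is not confirmed here. The helper names (`parse_driver_cuda_version`, `bundled_ptxas`, etc.) are hypothetical.

```python
import importlib.util
import re
import shutil
import subprocess
from pathlib import Path
from typing import List, Optional, Tuple

Version = Tuple[int, int]


def parse_driver_cuda_version(smi_output: str) -> Optional[Version]:
    """(major, minor) from the 'CUDA Version: X.Y' field in nvidia-smi output."""
    m = re.search(r"CUDA Version:\s*(\d+)\.(\d+)", smi_output)
    return (int(m.group(1)), int(m.group(2))) if m else None


def parse_ptxas_release(ptxas_output: str) -> Optional[Version]:
    """(major, minor) from the 'release X.Y' line of `ptxas --version`."""
    m = re.search(r"release\s+(\d+)\.(\d+)", ptxas_output)
    return (int(m.group(1)), int(m.group(2))) if m else None


def run(cmd: List[str]) -> Optional[str]:
    """Return the stdout of cmd, or None if the executable is missing or fails."""
    exe = shutil.which(cmd[0])
    if exe is None:
        return None
    try:
        result = subprocess.run([exe, *cmd[1:]], capture_output=True, text=True, check=True)
    except (OSError, subprocess.CalledProcessError):
        return None
    return result.stdout


def bundled_ptxas() -> Optional[Path]:
    """ptxas shipped by the pip package nvidia-cuda-nvcc-cu12 (assumed layout)."""
    try:
        spec = importlib.util.find_spec("nvidia.cuda_nvcc")
    except ModuleNotFoundError:
        return None
    if spec is None or not spec.submodule_search_locations:
        return None
    for location in spec.submodule_search_locations:
        candidate = Path(location) / "bin" / "ptxas"
        if candidate.is_file():
            return candidate
    return None


def driver_older_than(driver: Optional[Version], ptxas: Optional[Version]) -> Optional[bool]:
    """True if the driver's CUDA version is older than ptxas's; None if unknown."""
    if driver is None or ptxas is None:
        return None
    return driver < ptxas


if __name__ == "__main__":
    driver = parse_driver_cuda_version(run(["nvidia-smi"]) or "")
    print("driver CUDA version:", driver)
    candidates = {"ptxas on PATH": "ptxas", "pip-bundled ptxas": bundled_ptxas()}
    for label, exe in candidates.items():
        release = parse_ptxas_release(run([str(exe), "--version"]) or "") if exe else None
        print(f"{label}: {release}, driver older: {driver_older_than(driver, release)}")
```

Versions are compared as integer tuples rather than strings so that, for example, 12.10 correctly sorts after 12.9. If the PATH ptxas matches the driver but the pip-bundled one is newer, that would be consistent with the explanation above.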
Hi all,
I work on a cluster with CUDA 12.4, and I'm trying to use Brax through the TorchRL wrapper. Among other issues, I keep running into a strange compatibility warning. Here is the setup:
```python
import brax.envs

base_env = brax.envs.get_environment("halfcheetah")
```
which gives
While this is not a blocking issue, it is still very annoying.
If I run nvidia-smi, I get
If I run I get
There seems to be a mismatch between the warning and the output of this last step...
Here are the versions of the related packages:
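For a report like this, the relevant package versions can be gathered with the standard library's `importlib.metadata`. This is a minimal sketch; the package list is an assumption about which distributions matter for a JAX/Brax/TorchRL setup on CUDA 12, and `installed_versions` is a hypothetical helper name.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional

# Assumed selection of distributions worth reporting; adjust to your environment.
DEFAULT_PACKAGES = (
    "jax",
    "jaxlib",
    "jax-cuda12-plugin",
    "jax-cuda12-pjrt",
    "nvidia-cuda-nvcc-cu12",
    "brax",
    "torch",
    "torchrl",
)


def installed_versions(packages: Iterable[str] = DEFAULT_PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:24} {version or 'not installed'}")
```

If JAX itself imports cleanly, `jax.print_environment_info()` reports similar information for the JAX packages and also includes the nvidia-smi output.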
I tried reinstalling everything (jax, brax, CUDA) with both pip and conda, but the warning never went away.
Thank you very much in advance!