stergiosba closed this 4 months ago
fyi @Giulero @flferretti
By default, installing JAX pulls in the CPU version of jaxlib. If you're interested in the GPU version, I'd suggest installing it as described in the JAX documentation after installing ADAM:
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Personally, I am aware of the installation process for using the GPU with JAX. I just wanted to point this out so that others are not confused when their code is not as fast as they hoped.
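A quick way to avoid that confusion is to ask JAX which backend it actually picked up before benchmarking anything. The following is a minimal sketch that assumes a JAX release providing `jax.default_backend()` and `jax.devices()`. The helper name `describe_jax_backend` is hypothetical and not part of ADAM or JAX.

```python
import jax


def describe_jax_backend() -> dict:
    """Hypothetical helper: summarize which backend and devices JAX is using."""
    backend = jax.default_backend()  # typically "cpu", "gpu", or "tpu"
    devices = jax.devices()  # devices of the default backend
    return {
        "backend": backend,
        "device_count": len(devices),
        "platforms": sorted({d.platform for d in devices}),
        "gpu_available": backend == "gpu",
    }


if __name__ == "__main__":
    info = describe_jax_backend()
    print(info)
    if not info["gpu_available"]:
        # A plain `pip install jax` (or the ADAM jax extra) ends up here.
        print("JAX is running on CPU only; see the JAX installation docs for GPU support.")
```

With only the default CPU jaxlib installed, the backend reported is `"cpu"`, which explains slower-than-expected runs.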
Good point. Would you be interested (if it is OK with the maintainer, @Giulero) in adding a sentence to the README warning about this, perhaps linking to https://github.com/google/jax#instructions or https://jax.readthedocs.io/en/latest/installation.html#pip-installation-gpu-cuda-installed-via-pip-easier ?
Thanks @stergiosba for opening the issue! @flferretti is right. Yes @traversaro, I agree!
I guess this is solved. Closing!
As additional info for anyone finding this via Google: since ~May 2025, the jax package on conda-forge for Linux amd64 installs a working GPU-enabled jax by default. That is, on a linux-64 system with a working NVIDIA driver installation,
conda create -n adamjax -c conda-forge adam-robotics jax
installs a GPU-enabled jax out of the box.
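To confirm what a given environment (for example the `adamjax` conda environment above) actually installed, JAX can print its own environment report. This sketch assumes a JAX release that includes `jax.print_environment_info` (available in recent 0.4.x releases). The report lists the jax, jaxlib, numpy, and Python versions and appends `nvidia-smi` output when that tool is available.

```python
import jax

# Get the same text that jax.print_environment_info() would print,
# as a string so it can be logged or inspected.
report = jax.print_environment_info(return_string=True)
print(report)

# "gpu" here means JAX found a usable GPU backend in this environment.
print("default backend:", jax.default_backend())
```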
Installing ADAM's JAX extra with pip, as below, only provides the CPU version of JAX. Is this intended?
pip install adam-robotics[jax]