ami-iit / adam

adam implements a collection of algorithms for calculating rigid-body dynamics in Jax, CasADi, PyTorch, and Numpy.
https://adam-docs.readthedocs.io/en/latest/
BSD 3-Clause "New" or "Revised" License

Jax installation for GPU #50

Closed stergiosba closed 4 months ago

stergiosba commented 11 months ago

The pip installation with the JAX extra only pulls in the CPU version of JAX. Is this intended?

pip install "adam-robotics[jax]"

traversaro commented 11 months ago

fyi @Giulero @flferretti

flferretti commented 11 months ago

By default, installing JAX pulls in the CPU version of jaxlib. If you're interested in the GPU version, I'd suggest installing it as described in the JAX documentation after installing ADAM:

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

stergiosba commented 11 months ago

Personally, I'm aware of how to install JAX with GPU support. I just wanted to point this out so that others aren't confused when their code is not as fast as they hoped.

traversaro commented 11 months ago

> Personally I am aware of the installation process for using the GPU with Jax. I just wanted to point this out so that others are not confused when their code is not as fast as they hoped.

Good point. Would you be interested (if it's OK with the maintainer, @Giulero) in adding a sentence to the README warning about this, perhaps linking to https://github.com/google/jax#instructions or https://jax.readthedocs.io/en/latest/installation.html#pip-installation-gpu-cuda-installed-via-pip-easier ?

Giulero commented 11 months ago

Thanks @stergiosba for opening the issue! @flferretti is right. Yes @traversaro, I agree!

Giulero commented 4 months ago

I guess this is solved. Closing!

traversaro commented 4 months ago

As additional info for anyone finding this via Google: since ~May 2025, the jax package on conda-forge for Linux amd64 installs a working GPU-enabled JAX by default. For example, on a linux-64 system with a working NVIDIA driver installation, the following installs a GPU-powered JAX out of the box:

conda create -n adamjax -c conda-forge adam-robotics jax
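Whichever installation route is used (pip followed by a CUDA-enabled jaxlib, or conda-forge), it is easy to end up on the CPU backend without noticing, which is exactly the "not as fast as hoped" situation described above. A minimal sketch of a sanity check, assuming bash and a `python3` on PATH that belongs to the same environment as ADAM. The function names `jax_backend` and `warn_if_cpu_only` are hypothetical (not part of ADAM or JAX); the only JAX call used is `jax.default_backend()`, which returns the platform name such as "cpu" or "gpu". Note that it treats anything other than "gpu" (including "tpu") as a warning, since the concern here is CUDA.

```shell
# Hypothetical helpers (not part of adam or JAX): report which backend JAX
# selected and warn when it silently fell back to the CPU.
jax_backend() {
  # jax.default_backend() prints the platform name, e.g. "cpu" or "gpu"
  python3 -c 'import jax; print(jax.default_backend())'
}

warn_if_cpu_only() {
  local backend
  # Exit code 2: jax could not be imported by this python3
  backend="$(jax_backend)" || {
    echo "could not import jax with python3" >&2
    return 2
  }
  # Exit code 1: jax imported, but it is not using a GPU
  if [ "$backend" != "gpu" ]; then
    echo "JAX is running on '$backend'; install a CUDA-enabled jaxlib to use the GPU" >&2
    return 1
  fi
  # Exit code 0: GPU backend selected
  echo "JAX is running on the GPU"
}
```

Usage: source the snippet inside the activated environment (for example after `conda activate adamjax`) and run `warn_if_cpu_only`; a non-zero exit code makes it easy to use in CI or setup scripts.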