jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

JAX and TORCH #18032

Closed ywsslr closed 6 months ago

ywsslr commented 1 year ago

Description

When I install only the latest JAX with CUDA (pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html), I can use JAX on the GPU. But when I install torch afterwards (pip install torch), I can no longer use JAX on the GPU; it reports that the installed CUDA or cuSOLVER version is older than the one JAX was built against. Why? Would an older JAX version avoid this? If so, how can I pip install jax[cuda] with the matching version?

What jax/jaxlib version are you using?

jax-0.4.18 jaxlib-0.4.18+cuda12.cudnn89

Which accelerator(s) are you using?

GPU

Additional system info

3.10.9/Linux

NVIDIA GPU info

No response

hawkinsp commented 1 year ago

That's correct. The current releases of PyTorch and JAX have incompatible CUDA version dependencies.

I reported this issue to the PyTorch developers a while back, but there has been no interest in relaxing their CUDA version dependencies.

My recommendations:

Does that resolve your problem?

Hope that helps!

adam-hartshorne commented 1 year ago

This is quite annoying (and inconvenient) now that people have written torch2jax functionality which allows GPU-accelerated interaction,

https://github.com/samuela/torch2jax https://github.com/rdyro/torch2jax

flferretti commented 1 year ago

Hi @ywsslr, I've been experimenting with using Torch and JAX simultaneously for a while. I'm currently working in a Docker container in which they both work on GPU.

JAX was installed according to the official documentation as:

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Here is the Conda YAML of the environment. It probably contains some extra packages, but I hope it helps:

conda environment

```yaml
name: base
channels:
  - nvidia
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=2_gnu
  - absl-py=1.4.0=py310h06a4308_0
  - alsa-lib=1.2.10=hd590300_0
  - appdirs=1.4.4=pyhd3eb1b0_0
  - asttokens=2.0.5=pyhd3eb1b0_0
  - attr=2.5.1=h166bdaf_1
  - backcall=0.2.0=pyhd3eb1b0_0
  - binutils=2.40=hdd6e379_0
  - binutils_impl_linux-64=2.40=hf600244_0
  - binutils_linux-64=2.40=hbdbef99_2
  - blas=1.0=openblas
  - boltons=23.0.0=pyhd8ed1ab_0
  - brotli=1.0.9=he6710b0_2
  - brotli-python=1.1.0=py310hc6cd4ac_0
  - bzip2=1.0.8=h7f98852_4
  - c-ares=1.19.1=hd590300_0
  - c-compiler=1.6.0=hd590300_0
  - ca-certificates=2023.7.22=hbcca054_0
  - cairo=1.16.0=hb05425b_5
  - certifi=2023.7.22=pyhd8ed1ab_0
  - cffi=1.15.1=py310h255011f_3
  - charset-normalizer=3.2.0=pyhd8ed1ab_0
  - chex=0.1.5=py310h06a4308_0
  - click=8.0.4=py310h06a4308_0
  - colorama=0.4.6=pyhd8ed1ab_0
  - coloredlogs=15.0.1=py310h06a4308_1
  - compilers=1.6.0=ha770c72_0
  - conda=23.3.1=py310hff52083_0
  - conda-package-handling=2.2.0=pyh38be061_0
  - conda-package-streaming=0.9.0=pyhd8ed1ab_0
  - contourpy=1.0.5=py310hdb19cb5_0
  - cryptography=41.0.3=py310h75e40e8_0
  - cuda-nvcc=11.3.58=h2467b9f_0
  - cuda-version=11.8=h70ddcb2_2
  - cudatoolkit=11.8.0=h4ba93d1_12
  - cudnn=8.9.2.26=cuda11_0
  - cupti=11.8.0=he078b1a_0
  - cxx-compiler=1.6.0=h00ab1b0_0
  - cycler=0.11.0=pyhd3eb1b0_0
  - dbus=1.13.18=hb2f20db_0
  - deap=1.4.1=py310h7cbd5c2_0
  - decorator=5.1.1=pyhd3eb1b0_0
  - dm-tree=0.1.7=py310h6a678d5_1
  - docker-pycreds=0.4.0=pyhd3eb1b0_0
  - docstring_parser=0.15=pyhd8ed1ab_0
  - exceptiongroup=1.0.4=py310h06a4308_0
  - executing=0.8.3=pyhd3eb1b0_0
  - expat=2.5.0=h6a678d5_0
  - filelock=3.9.0=py310h06a4308_0
  - flax=0.6.1=pyhd8ed1ab_1
  - fmt=9.1.0=h924138e_0
  - font-ttf-dejavu-sans-mono=2.37=hd3eb1b0_0
  - font-ttf-inconsolata=2.001=hcb22688_0
  - font-ttf-source-code-pro=2.030=hd3eb1b0_0
  - font-ttf-ubuntu=0.83=h8b1ccd4_0
  - fontconfig=2.14.2=h14ed4e7_0
  - fonts-anaconda=1=h8fa9717_0
  - fonts-conda-ecosystem=1=hd3eb1b0_0
  - fonttools=4.25.0=pyhd3eb1b0_0
  - fortran-compiler=1.6.0=heb67821_0
  - freetype=2.12.1=h4a9f257_0
  - frozendict=2.3.8=py310h2372a71_0
  - gcc=12.3.0=h8d2909c_2
  - gcc_impl_linux-64=12.3.0=he2b93b0_1
  - gcc_linux-64=12.3.0=h76fc315_2
  - gettext=0.21.1=h27087fc_0
  - gfortran=12.3.0=h499e0f7_2
  - gfortran_impl_linux-64=12.3.0=hfcedea8_1
  - gfortran_linux-64=12.3.0=h7fe76b4_2
  - gitdb=4.0.7=pyhd3eb1b0_0
  - gitpython=3.1.30=py310h06a4308_0
  - glib=2.78.0=hfc55251_0
  - glib-tools=2.78.0=hfc55251_0
  - gmp=6.2.1=h295c915_3
  - gmpy2=2.1.2=py310heeb90bb_0
  - graphite2=1.3.14=h295c915_1
  - gst-plugins-base=1.22.5=h8e1006c_1
  - gstreamer=1.22.5=h98fc4e7_1
  - gxx=12.3.0=h8d2909c_2
  - gxx_impl_linux-64=12.3.0=he2b93b0_1
  - gxx_linux-64=12.3.0=h8a814eb_2
  - harfbuzz=8.2.0=h3d44ed6_0
  - humanfriendly=10.0=py310h06a4308_1
  - icu=73.2=h59595ed_0
  - idna=3.4=pyhd8ed1ab_0
  - intel-openmp=2023.1.0=hdb19cb5_46305
  - ipython=8.15.0=py310h06a4308_0
  - jax-dataclasses=1.5.1=pyhd8ed1ab_0
  - jaxlie=1.3.3=pyhd8ed1ab_0
  - jedi=0.18.1=py310h06a4308_1
  - jinja2=3.1.2=py310h06a4308_0
  - jsonpatch=1.32=pyhd8ed1ab_0
  - jsonpointer=2.4=py310hff52083_0
  - kernel-headers_linux-64=2.6.32=he073ed8_16
  - keyutils=1.6.1=h166bdaf_0
  - kiwisolver=1.4.4=py310h6a678d5_0
  - krb5=1.21.2=h659d440_0
  - lame=3.100=h7b6447c_0
  - lcms2=2.15=h7f713cb_2
  - ld_impl_linux-64=2.40=h41732ed_0
  - lerc=4.0.0=h27087fc_0
  - libarchive=3.6.2=h039dbb9_1
  - libcap=2.69=h0f662aa_0
  - libclang=15.0.7=default_h7634d5b_3
  - libclang13=15.0.7=default_h9986a30_3
  - libcups=2.3.3=h4637d8d_4
  - libcurl=8.3.0=hca28451_0
  - libdeflate=1.19=hd590300_0
  - libedit=3.1.20191231=he28a2e2_2
  - libev=4.33=h516909a_1
  - libevent=2.1.12=hdbd6064_1
  - libexpat=2.5.0=hcb278e6_1
  - libffi=3.4.2=h7f98852_5
  - libflac=1.4.3=h59595ed_0
  - libgcc-devel_linux-64=12.3.0=h8bca6fd_1
  - libgcc-ng=13.2.0=h807b86a_0
  - libgcrypt=1.10.1=h166bdaf_0
  - libgfortran-ng=13.2.0=h69a702a_1
  - libgfortran5=13.2.0=ha4646dd_1
  - libglib=2.78.0=hebfc3b9_0
  - libgomp=13.2.0=h807b86a_0
  - libgpg-error=1.47=h71f35ed_0
  - libiconv=1.17=h166bdaf_0
  - libjpeg-turbo=2.1.5.1=hd590300_1
  - libllvm15=15.0.7=h5cf9203_3
  - libmamba=1.2.0=hcea66bb_0
  - libmambapy=1.2.0=py310h1428755_0
  - libnghttp2=1.52.0=h61bc06f_0
  - libnsl=2.0.0=h7f98852_0
  - libogg=1.3.5=h27cfd23_1
  - libopenblas=0.3.21=h043d6bf_0
  - libopus=1.3.1=h7b6447c_0
  - libpng=1.6.39=h5eee18b_0
  - libpq=15.4=hfc447b1_0
  - libprotobuf=3.20.3=he621ea3_0
  - libsanitizer=12.3.0=h0f45ef3_1
  - libsndfile=1.2.2=hbc2eb40_0
  - libsolv=0.7.24=hfc55251_4
  - libsqlite=3.43.0=h2797004_0
  - libssh2=1.11.0=h0841786_0
  - libstdcxx-devel_linux-64=12.3.0=h8bca6fd_1
  - libstdcxx-ng=13.2.0=h7e041cc_0
  - libsystemd0=254=h3516f8a_0
  - libtiff=4.6.0=h29866fb_1
  - libuuid=2.38.1=h0b41bf4_0
  - libvorbis=1.3.7=h7b6447c_0
  - libwebp-base=1.3.2=h5eee18b_0
  - libxcb=1.15=h7f8727e_0
  - libxkbcommon=1.5.0=h5d7e998_3
  - libxml2=2.11.5=h232c23b_1
  - libzlib=1.2.13=hd590300_5
  - lz4-c=1.9.4=hcb278e6_0
  - lzo=2.10=h516909a_1000
  - magma=2.7.1=h2c23e93_0
  - mamba=1.2.0=py310h51d5547_0
  - markdown-it-py=2.2.0=py310h06a4308_1
  - markupsafe=2.1.1=py310h7f8727e_0
  - mashumaro=3.6=py310h06a4308_0
  - matplotlib=3.7.2=py310h06a4308_0
  - matplotlib-base=3.7.2=py310h1128e8f_0
  - matplotlib-inline=0.1.6=py310h06a4308_0
  - mdurl=0.1.0=py310h06a4308_0
  - mkl=2023.1.0=h213fc3f_46343
  - mpc=1.1.0=h10f8cd9_1
  - mpfr=4.0.2=hb69a4c5_1
  - mpg123=1.31.3=hcb278e6_0
  - mpmath=1.3.0=py310h06a4308_0
  - msgpack-python=1.0.3=py310hd09550d_0
  - munkres=1.1.4=py_0
  - mysql-common=8.0.33=hf1915f5_4
  - mysql-libs=8.0.33=hca2cd23_4
  - ncurses=6.4=hcb278e6_0
  - networkx=3.1=py310h06a4308_0
  - ninja=1.10.2=h06a4308_5
  - ninja-base=1.10.2=hd09550d_5
  - nspr=4.35=h6a678d5_0
  - nss=3.92=h1d7d5a4_0
  - numpy=1.25.2=py310heeff2f4_0
  - numpy-base=1.25.2=py310h8a23956_0
  - openjpeg=2.5.0=h488ebb8_3
  - openssl=3.1.2=hd590300_0
  - opt_einsum=3.3.0=pyhd3eb1b0_1
  - optax=0.1.4=py310h06a4308_0
  - overrides=7.4.0=pyhd8ed1ab_0
  - packaging=23.1=pyhd8ed1ab_0
  - parso=0.8.3=pyhd3eb1b0_0
  - pathtools=0.1.2=pyhd3eb1b0_1
  - pcre2=10.40=hc3806b6_0
  - pexpect=4.8.0=pyhd3eb1b0_3
  - pickleshare=0.7.5=pyhd3eb1b0_1003
  - pillow=10.0.1=py310h29da1c1_0
  - pip=23.2.1=pyhd8ed1ab_0
  - pixman=0.40.0=h7f8727e_1
  - pluggy=1.3.0=pyhd8ed1ab_0
  - ply=3.11=py310h06a4308_0
  - pptree=3.1=pyhd8ed1ab_0
  - prompt-toolkit=3.0.36=py310h06a4308_0
  - protobuf=3.20.3=py310h6a678d5_0
  - psutil=5.9.0=py310h5eee18b_0
  - ptyprocess=0.7.0=pyhd3eb1b0_2
  - pulseaudio-client=16.1=hb77b528_5
  - pure_eval=0.2.2=pyhd3eb1b0_0
  - pybind11-abi=4=hd8ed1ab_3
  - pycosat=0.6.4=py310h5764c6d_1
  - pycparser=2.21=pyhd8ed1ab_0
  - pygments=2.15.1=py310h06a4308_1
  - pyopenssl=23.2.0=pyhd8ed1ab_1
  - pyparsing=3.0.9=py310h06a4308_0
  - pyqt=5.15.9=py310h04931ad_4
  - pyqt5-sip=12.12.2=py310hc6cd4ac_4
  - pysocks=1.7.1=pyha2e5f31_6
  - python=3.10.8=h4a9ceb5_0_cpython
  - python-dateutil=2.8.2=pyhd3eb1b0_0
  - python_abi=3.10=3_cp310
  - pytorch=2.0.1=gpu_cuda118py310h7799f5a_0
  - pyyaml=6.0=py310h5eee18b_1
  - qt-main=5.15.8=hc47bfe8_16
  - readline=8.2=h8228510_1
  - reproc=14.2.4=h0b41bf4_0
  - reproc-cpp=14.2.4=hcb278e6_0
  - requests=2.31.0=pyhd8ed1ab_0
  - rich=13.3.5=py310h06a4308_0
  - ruamel.yaml=0.17.32=py310h2372a71_0
  - ruamel.yaml.clib=0.2.7=py310h1fa729e_1
  - scipy=1.11.1=py310heeff2f4_0
  - sentry-sdk=1.9.0=py310h06a4308_0
  - setproctitle=1.2.2=py310h7f8727e_0
  - setuptools=68.2.2=pyhd8ed1ab_0
  - shtab=1.6.4=pyhd8ed1ab_1
  - sip=6.7.11=py310hc6cd4ac_0
  - six=1.16.0=pyhd3eb1b0_1
  - smmap=4.0.0=pyhd3eb1b0_0
  - stack_data=0.2.0=pyhd3eb1b0_0
  - sympy=1.11.1=py310h06a4308_0
  - sysroot_linux-64=2.12=he073ed8_16
  - tbb=2021.8.0=hdb19cb5_0
  - tk=8.6.12=h27826a3_0
  - toml=0.10.2=pyhd3eb1b0_0
  - tomli=2.0.1=py310h06a4308_0
  - toolz=0.12.0=pyhd8ed1ab_0
  - tornado=6.3.2=py310h5eee18b_0
  - tqdm=4.66.1=pyhd8ed1ab_0
  - traitlets=5.7.1=py310h06a4308_0
  - typing-extensions=4.7.1=py310h06a4308_0
  - typing_extensions=4.7.1=py310h06a4308_0
  - typing_utils=0.1.0=pyhd8ed1ab_0
  - tyro=0.5.7=pyhd8ed1ab_0
  - tzdata=2023c=h71feb2d_0
  - urllib3=2.0.4=pyhd8ed1ab_0
  - wandb=0.15.10=pyhd8ed1ab_0
  - wcwidth=0.2.5=pyhd3eb1b0_0
  - wheel=0.41.2=pyhd8ed1ab_0
  - xcb-util=0.4.0=hd590300_1
  - xcb-util-image=0.4.0=h8ee46fc_1
  - xcb-util-keysyms=0.4.0=h8ee46fc_1
  - xcb-util-renderutil=0.3.9=hd590300_1
  - xcb-util-wm=0.4.1=h8ee46fc_1
  - xkeyboard-config=2.39=hd590300_0
  - xorg-kbproto=1.0.7=h7f98852_1002
  - xorg-libice=1.1.1=hd590300_0
  - xorg-libsm=1.2.4=h7391055_0
  - xorg-libx11=1.8.6=h8ee46fc_0
  - xorg-libxau=1.0.11=hd590300_0
  - xorg-libxext=1.3.4=h0b41bf4_2
  - xorg-libxrender=0.9.11=hd590300_0
  - xorg-renderproto=0.11.1=h7f98852_1002
  - xorg-xextproto=7.3.0=h0b41bf4_1003
  - xorg-xf86vidmodeproto=2.3.1=h7f98852_1002
  - xorg-xproto=7.0.31=h27cfd23_1007
  - xz=5.2.6=h166bdaf_0
  - yaml=0.2.5=h7b6447c_0
  - yaml-cpp=0.7.0=h27087fc_2
  - zlib=1.2.13=hd590300_5
  - zstandard=0.19.0=py310h5764c6d_0
  - zstd=1.5.5=hfc55251_0
  - pip:
      - jax==0.4.18
      - jaxlib==0.4.18+cuda12.cudnn89
      - ml-dtypes==0.3.1
      - nvidia-cublas-cu12==12.2.5.6
      - nvidia-cuda-cupti-cu12==12.2.142
      - nvidia-cuda-nvcc-cu12==12.2.140
      - nvidia-cuda-nvrtc-cu12==12.2.140
      - nvidia-cuda-runtime-cu12==12.2.140
      - nvidia-cudnn-cu12==8.9.4.25
      - nvidia-cufft-cu12==11.0.8.103
      - nvidia-cusolver-cu12==11.5.2.141
      - nvidia-cusparse-cu12==12.1.2.141
      - nvidia-nccl-cu12==2.18.3
      - nvidia-nvjitlink-cu12==12.2.140
prefix: /conda
```

ywsslr commented 1 year ago

Thank you all for your help. For some reason I can't try it right now, but I'll try it soon and report back.

aldopareja commented 1 year ago

ok people, this has been a 1 day nightmare. But finally got this to work on an H100 machine with cuda 12.2, without sudo.

then install pytorch from source as that post says!!!! and bualaaa

hawkinsp commented 11 months ago

No promises, but informally we're going to try to keep at least one JAX release have a version that is also released with PyTorch. Right now, that's the CUDA 11.8 release of JAX.

It's not a guarantee, though; it might happen that for some JAX and PyTorch versions there's no intersecting CUDA version.

pearu commented 10 months ago

I hit a similar issue when installing pytorch and jax into the same conda environment: when torch is loaded first, jax.devices() will list only cpu devices.

A short summary of diagnosis: It turns out that torch is built against cudnn version 8.7 while jaxlib is built against cudnn version 8.8 leading to an exception when executing jax._src.xla_bridge._check_cuda_versions().

Here follows a reproducer:

mamba create -n test-pytorch-jax pytorch pytorch-cuda=11.8 jaxlib=*=*cuda118* jax -c pytorch -c nvidia --no-channel-priority -y
mamba activate test-pytorch-jax

(note: using strict channel priority would lead to a mamba solver problem).

Import torch before checking jax.devices:

>>> import torch
>>> import jax
>>> jax.devices()
CUDA backend failed to initialize: Found cuDNN version 8700, but JAX was built against version 8800, which is newer. The copy of cuDNN that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[CpuDevice(id=0)]

Import torch after checking jax.devices:

>>> import jax
>>> jax.devices()
[cuda(id=0), cuda(id=1)]
>>> import torch
>>> jax.__version__
'0.4.23'
>>> torch.__version__
'2.1.2'
>>> from torch._C import _cudnn
>>> _cudnn.getCompileVersion()
(8, 7, 0)

Notice that the result of jaxlib.cuda._versions.cudnn_get_version() depends on whether torch was imported before or after the first call to jaxlib.cuda._versions.cudnn_get_version():

>>> import jaxlib.cuda._versions
>>> jaxlib.cuda._versions.cudnn_get_version()
8902
>>> import torch
>>> jaxlib.cuda._versions.cudnn_get_version()
8902

vs

>>> import torch
>>> import jaxlib.cuda._versions
>>> jaxlib.cuda._versions.cudnn_get_version()
8700

This qualifies as an incompatible linkage issue: since libcudnn is dynamically loaded, cudnnGetVersion ought to return the version of the loaded library, not the version of the library that the software was built against. The behavior above suggests that torch was linked with libcudnn statically.
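One way to see which copy of libcudnn actually ends up in the process is to list the shared objects the Linux loader has mapped. The following is a minimal sketch, not part of JAX or PyTorch: the helper name loaded_shared_objects is made up for illustration, and it assumes a Linux system where /proc/self/maps is readable.

```python
from pathlib import Path


def loaded_shared_objects(name_fragment, maps_path="/proc/self/maps"):
    """Return sorted paths of mapped files whose file name contains name_fragment.

    Hypothetical helper for illustration. Linux-only: it parses the process
    memory map and returns an empty list when that file is unavailable.
    """
    try:
        text = Path(maps_path).read_text()
    except OSError:
        return []
    paths = set()
    for line in text.splitlines():
        # Fields: address, perms, offset, dev, inode, and an optional path.
        fields = line.split(maxsplit=5)
        if len(fields) == 6 and name_fragment in Path(fields[5]).name:
            paths.add(fields[5].strip())
    return sorted(paths)


if __name__ == "__main__":
    # Import torch and/or jax above this line, in the order under test.
    for path in loaded_shared_objects("libcudnn"):
        print(path)
```

Running it once with torch imported first and once with jax initialized first (for example after calling jax.devices()) should show whether the two orders map libcudnn from different locations.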

A possible resolution: Note that cuDNN minor releases are backward compatible with applications built against the same or earlier minor release. Hence, as long as jaxlib and torch are built against libcudnn with the same major version (8), the jax version check ought to ignore cudnn minor versions. Here is a patch:

diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py
index 7977f6329..17c14bc5a 100644
--- a/jax/_src/xla_bridge.py
+++ b/jax/_src/xla_bridge.py
@@ -263,7 +263,7 @@ def _check_cuda_versions():
       cuda_versions.cudnn_build_version,
       # NVIDIA promise both backwards and forwards compatibility for cuDNN patch
       # versions: https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#api-compat
-      scale_for_comparison=100,
+      scale_for_comparison=1000,
   )
   _version_check("cuFFT", cuda_versions.cufft_get_version,
                  cuda_versions.cufft_build_version,
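To make the effect of the patch concrete, here is a small sketch of the comparison. It assumes that the version check integer-divides both versions by scale_for_comparison before comparing them, and that cuDNN 8.x encodes its version as major*1000 + minor*100 + patch. The function names decode_cudnn_version and is_new_enough are illustrative and are not JAX APIs.

```python
def decode_cudnn_version(version):
    """Split a cuDNN 8.x integer version such as 8902 into (major, minor, patch)."""
    return version // 1000, (version % 1000) // 100, version % 100


def is_new_enough(installed, built, scale_for_comparison):
    """Hypothetical stand-in for the comparison inside the version check.

    Both versions are integer-divided by scale_for_comparison first, so the
    digits below that scale are ignored (an assumption about the real check).
    """
    return installed // scale_for_comparison >= built // scale_for_comparison


installed, built = 8700, 8800  # values from the error message above
print(decode_cudnn_version(installed), decode_cudnn_version(built))
print(is_new_enough(installed, built, scale_for_comparison=100))   # compares 87 with 88
print(is_new_enough(installed, built, scale_for_comparison=1000))  # compares 8 with 8
```

With a scale of 100 the check compares major.minor and rejects 8.7 against 8.8; with a scale of 1000 it compares only the major version and accepts it.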
siddharthab commented 7 months ago

No promises, but informally we're going to try to keep at least one JAX release have a version that is also released with PyTorch. Right now, that's the CUDA 11.8 release of JAX.

The latest pair of versions I could find that was compatible was jax[cuda11-pip,cuda11_pip]==0.4.10 and torch==2.2.1+cu118. The main conflict with later JAX versions is cuDNN: JAX wants >8.8, but torch wants ==8.7.

One way to check this would be:

cat > requirements.in <<EOF
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
--extra-index-url=https://download.pytorch.org/whl

jax[cuda11_pip]
torch==2.2.1+cu118
EOF

pip-compile

# Check the contents of requirements.txt.
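After installing, the same information can be read back from the environment itself. This is a minimal sketch using only the standard library; the helper name installed_versions and the default prefixes are illustrative choices, not anything provided by pip, JAX, or PyTorch.

```python
from importlib import metadata


def installed_versions(prefixes=("jax", "torch", "nvidia-")):
    """Map installed distribution names that start with any prefix to versions.

    Hypothetical helper for illustration; prefix matching is case-insensitive
    on the distribution name, so "jax" also matches "jaxlib".
    """
    found = {}
    for dist in metadata.distributions():
        name = dist.metadata.get("Name")
        if name and name.lower().startswith(tuple(prefixes)):
            found[name] = dist.version
    return dict(sorted(found.items()))


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}=={version}")
```

The nvidia-cudnn-* line in the output shows which cuDNN wheel the resolver actually picked, which is the package the two frameworks disagree on.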
siddharthab commented 7 months ago

A workaround that works better for us is to use CUDA 11 with Jax, but CUDA 12 with Torch. So basically jax[cuda11_pip] and torch in our requirements file works for us.

lucascolley commented 7 months ago

A workaround that works better for us is to use CUDA 11 with Jax, but CUDA 12 with Torch.

How did you get this to work? I'm using conda, but after installing pytorch-cuda=12.1 I get the following error from JAX:

E   RuntimeError: Unable to initialize backend 'cuda': Unable to load CUDA. Is it installed? (set JAX_PLATFORMS='' to automatically choose an available backend)
siddharthab commented 7 months ago

We did not have to do anything special. Just installed the two packages in a clean env, and both worked.

flferretti commented 7 months ago

The only way I was able to solve the environment with both JAX and PyTorch on CUDA12 was to install some packages from the nvidia channel:

mamba create -n jaxTorch jaxlib pytorch cuda-nvcc -c conda-forge -c nvidia -c pytorch
>>> import torch
>>> import jax
>>> torch.cuda.is_available()
True
>>> jax.devices()
[cuda(id=0)]
>>> import jaxlib.cuda._versions
>>> jaxlib.cuda._versions.cudnn_get_version()
8902
>>> torch._C._cudnn.getCompileVersion()
(8, 9, 2)
conda list

```yaml
# Name  Version  Build  Channel
_libgcc_mutex  0.1  conda_forge  conda-forge
_openmp_mutex  4.5  2_kmp_llvm  conda-forge
_sysroot_linux-64_curr_repodata_hack  3  h69a702a_14  conda-forge
binutils_impl_linux-64  2.40  hf600244_0  conda-forge
binutils_linux-64  2.40  hdade7a5_3  conda-forge
blas  2.116  mkl  conda-forge
blas-devel  3.9.0  16_linux64_mkl  conda-forge
bzip2  1.0.8  hd590300_5  conda-forge
c-ares  1.27.0  hd590300_0  conda-forge
ca-certificates  2024.2.2  hbcca054_0  conda-forge
cuda-cccl_linux-64  12.1.109  ha770c72_0  conda-forge
cuda-cudart  12.1.105  hd3aeb46_0  conda-forge
cuda-cudart-dev  12.1.105  hd3aeb46_0  conda-forge
cuda-cudart-dev_linux-64  12.1.105  h59595ed_0  conda-forge
cuda-cudart-static  12.1.105  hd3aeb46_0  conda-forge
cuda-cudart-static_linux-64  12.1.105  h59595ed_0  conda-forge
cuda-cudart_linux-64  12.1.105  h59595ed_0  conda-forge
cuda-cupti  12.1.105  h59595ed_0  conda-forge
cuda-driver-dev_linux-64  12.1.105  h59595ed_0  conda-forge
cuda-libraries  12.1.0  0  nvidia
cuda-nvcc  12.1.105  hcdd1206_1  conda-forge
cuda-nvcc-dev_linux-64  12.1.105  ha770c72_0  conda-forge
cuda-nvcc-impl  12.1.105  hd3aeb46_0  conda-forge
cuda-nvcc-tools  12.1.105  hd3aeb46_0  conda-forge
cuda-nvcc_linux-64  12.1.105  h8a487aa_1  conda-forge
cuda-nvrtc  12.1.105  hd3aeb46_0  conda-forge
cuda-nvtx  12.1.105  h59595ed_0  conda-forge
cuda-opencl  12.1.105  h59595ed_0  conda-forge
cuda-runtime  12.1.0  0  nvidia
cuda-version  12.1  h1d6eff3_3  conda-forge
cudnn  8.9.7.29  h092f7fd_3  conda-forge
filelock  3.13.3  pyhd8ed1ab_0  conda-forge
gcc_impl_linux-64  12.3.0  he2b93b0_5  conda-forge
gcc_linux-64  12.3.0  h6477408_3  conda-forge
gxx_impl_linux-64  12.3.0  he2b93b0_5  conda-forge
gxx_linux-64  12.3.0  h4a1b8e8_3  conda-forge
icu  73.2  h59595ed_0  conda-forge
importlib-metadata  7.1.0  pyha770c72_0  conda-forge
importlib_metadata  7.1.0  hd8ed1ab_0  conda-forge
jax  0.4.25  pyhd8ed1ab_0  conda-forge
jaxlib  0.4.23  cuda120py312hc008a70_200  conda-forge
jinja2  3.1.3  pyhd8ed1ab_0  conda-forge
kernel-headers_linux-64  3.10.0  h4a8ded7_14  conda-forge
ld_impl_linux-64  2.40  h41732ed_0  conda-forge
libabseil  20240116.1  cxx17_h59595ed_2  conda-forge
libblas  3.9.0  16_linux64_mkl  conda-forge
libcblas  3.9.0  16_linux64_mkl  conda-forge
libcublas  12.1.0.26  0  nvidia
libcufft  11.0.2.4  0  nvidia
libcufile  1.6.1.9  hd3aeb46_0  conda-forge
libcurand  10.3.2.106  hd3aeb46_0  conda-forge
libcusolver  11.4.4.55  0  nvidia
libcusparse  12.0.2.55  0  nvidia
libexpat  2.6.2  h59595ed_0  conda-forge
libffi  3.4.2  h7f98852_5  conda-forge
libgcc-devel_linux-64  12.3.0  h8bca6fd_105  conda-forge
libgcc-ng  13.2.0  h807b86a_5  conda-forge
libgfortran-ng  13.2.0  h69a702a_5  conda-forge
libgfortran5  13.2.0  ha4646dd_5  conda-forge
libgomp  13.2.0  h807b86a_5  conda-forge
libgrpc  1.62.1  h15f2491_0  conda-forge
libhwloc  2.9.3  default_h554bfaf_1009  conda-forge
libiconv  1.17  hd590300_2  conda-forge
liblapack  3.9.0  16_linux64_mkl  conda-forge
liblapacke  3.9.0  16_linux64_mkl  conda-forge
libnpp  12.0.2.50  0  nvidia
libnsl  2.0.1  hd590300_0  conda-forge
libnvjitlink  12.1.105  hd3aeb46_0  conda-forge
libnvjpeg  12.1.1.14  0  nvidia
libprotobuf  4.25.3  h08a7969_0  conda-forge
libre2-11  2023.09.01  h5a48ba9_2  conda-forge
libsanitizer  12.3.0  h0f45ef3_5  conda-forge
libsqlite  3.45.2  h2797004_0  conda-forge
libstdcxx-devel_linux-64  12.3.0  h8bca6fd_105  conda-forge
libstdcxx-ng  13.2.0  h7e041cc_5  conda-forge
libuuid  2.38.1  h0b41bf4_0  conda-forge
libxcrypt  4.4.36  hd590300_1  conda-forge
libxml2  2.12.6  h232c23b_1  conda-forge
libzlib  1.2.13  hd590300_5  conda-forge
llvm-openmp  15.0.7  h0cdce71_0  conda-forge
markupsafe  2.1.5  py312h98912ed_0  conda-forge
mkl  2022.1.0  h84fe81f_915  conda-forge
mkl-devel  2022.1.0  ha770c72_916  conda-forge
mkl-include  2022.1.0  h84fe81f_915  conda-forge
ml_dtypes  0.3.2  py312hfb8ada1_0  conda-forge
mpmath  1.3.0  pyhd8ed1ab_0  conda-forge
nccl  2.20.5.1  h3a97aeb_0  conda-forge
ncurses  6.4.20240210  h59595ed_0  conda-forge
networkx  3.2.1  pyhd8ed1ab_0  conda-forge
numpy  1.26.4  py312heda63a1_0  conda-forge
ocl-icd  2.3.2  hd590300_1  conda-forge
openssl  3.2.1  hd590300_1  conda-forge
opt-einsum  3.3.0  hd8ed1ab_2  conda-forge
opt_einsum  3.3.0  pyhc1e730c_2  conda-forge
pip  24.0  pyhd8ed1ab_0  conda-forge
python  3.12.2  hab00c5b_0_cpython  conda-forge
python_abi  3.12  4_cp312  conda-forge
pytorch  2.2.1  py3.12_cuda12.1_cudnn8.9.2_0  pytorch
pytorch-cuda  12.1  ha16c6d3_5  pytorch
pytorch-mutex  1.0  cuda  pytorch
pyyaml  6.0.1  py312h98912ed_1  conda-forge
re2  2023.09.01  h7f4b329_2  conda-forge
readline  8.2  h8228510_1  conda-forge
scipy  1.12.0  py312heda63a1_2  conda-forge
setuptools  69.2.0  pyhd8ed1ab_0  conda-forge
sympy  1.12  pyh04b8f61_3  conda-forge
sysroot_linux-64  2.17  h4a8ded7_14  conda-forge
tbb  2021.11.0  h00ab1b0_1  conda-forge
tk  8.6.13  noxft_h4845f30_101  conda-forge
typing_extensions  4.10.0  pyha770c72_0  conda-forge
tzdata  2024a  h0c530f3_0  conda-forge
wheel  0.43.0  pyhd8ed1ab_0  conda-forge
xz  5.2.6  h166bdaf_0  conda-forge
yaml  0.2.5  h7f98852_2  conda-forge
zipp  3.17.0  pyhd8ed1ab_0  conda-forge
```

fyi @traversaro

traversaro commented 7 months ago

The only way I was able to solve the environment with both JAX and PyTorch on CUDA12 was to install some packages from the nvidia channel:

FYI, at the moment it is not possible to get both jax and pytorch with cuda 12 only using conda-forge dependencies for this reason (I pinned several dependencies to get a clearer error):

traversaro@IITBMP014LW012:~$ mamba create -n jaxtorchcuda pytorch==2.1.2=*cuda* jaxlib==0.4.23=*cuda* jax cuda-version=12.* python==3.11.* cudatoolkit==12.*

Looking for: ['pytorch==2.1.2[build=*cuda*]', 'jaxlib==0.4.23[build=*cuda*]', 'jax', 'cuda-version=12', 'python=3.11', 'cudatoolkit=12']

conda-forge/linux-64                                        Using cache
conda-forge/noarch                                          Using cache
Could not solve for environment specs
The following packages are incompatible
├─ cuda-version 12**  is installable with the potential options
│  ├─ cuda-version [12.0|12.0.0] would require
│  │  └─ cudatoolkit 12.0|12.0.* , which can be installed;
│  ├─ cuda-version 12.1 would require
│  │  └─ cudatoolkit 12.1|12.1.* , which can be installed;
│  ├─ cuda-version 12.2 would require
│  │  └─ cudatoolkit 12.2|12.2.* , which can be installed;
│  ├─ cuda-version 12.3 would require
│  │  └─ cudatoolkit 12.3|12.3.* , which can be installed;
│  └─ cuda-version 12.4 would require
│     └─ cudatoolkit 12.4|12.4.* , which can be installed;
├─ cudatoolkit 12**  does not exist (perhaps a typo or a missing channel);
├─ jaxlib 0.4.23 *cuda* is installable with the potential options
│  ├─ jaxlib 0.4.23 would require
│  │  └─ cudatoolkit >=11.8,<12 , which conflicts with any installable versions previously reported;
│  ├─ jaxlib 0.4.23 would require
│  │  └─ libgrpc >=1.62.1,<1.63.0a0 , which requires
│  │     └─ libprotobuf >=4.25.3,<4.25.4.0a0 , which can be installed;
│  ├─ jaxlib 0.4.23 would require
│  │  ├─ cudatoolkit >=11.8,<12 , which conflicts with any installable versions previously reported;
│  │  └─ libgrpc >=1.59.3,<1.60.0a0 , which requires
│  │     └─ libprotobuf >=4.24.4,<4.24.5.0a0 , which conflicts with any installable versions previously reported;
│  └─ jaxlib 0.4.23 would require
│     └─ python_abi 3.12.* *_cp312, which requires
│        └─ python 3.12.* *_cpython, which can be installed;
├─ python 3.11**  is not installable because it conflicts with any installable versions previously reported;
└─ pytorch 2.1.2 *cuda* is installable with the potential options
   ├─ pytorch 2.1.2 would require
   │  ├─ libprotobuf >=4.25.1,<4.25.2.0a0 , which conflicts with any installable versions previously reported;
   │  └─ libtorch 2.1.2.*  with the potential options
   │     ├─ libtorch 2.1.2 would require
   │     │  └─ pytorch 2.1.2 cpu_generic_*_0, which can be installed;
   │     ├─ libtorch 2.1.2 would require
   │     │  └─ pytorch 2.1.2 cpu_generic_*_1, which can be installed;
   │     ├─ libtorch 2.1.2 would require
   │     │  └─ pytorch 2.1.2 cpu_generic_*_3, which can be installed;
   │     ├─ libtorch 2.1.2 would require
   │     │  └─ pytorch 2.1.2 cpu_mkl_*_100, which can be installed;
   │     ├─ libtorch 2.1.2 would require
   │     │  └─ pytorch 2.1.2 cpu_mkl_*_101, which can be installed;
   │     ├─ libtorch 2.1.2 would require
   │     │  └─ pytorch 2.1.2 cpu_mkl_*_103, which can be installed;
   │     ├─ libtorch 2.1.2 would require
   │     │  └─ pytorch 2.1.2 cuda112_*_300, which can be installed;
   │     ├─ libtorch 2.1.2 would require
   │     │  ├─ libprotobuf >=4.25.1,<4.25.2.0a0 , which conflicts with any installable versions previously reported;
   │     │  └─ pytorch 2.1.2 cuda112_*_301, which can be installed;
   │     ├─ libtorch 2.1.2 would require
   │     │  ├─ cudatoolkit >=11.8,<12 , which conflicts with any installable versions previously reported;
   │     │  ├─ libprotobuf >=4.25.1,<4.25.2.0a0 , which conflicts with any installable versions previously reported;
   │     │  └─ pytorch 2.1.2 cuda118_*_301, which can be installed;
   │     ├─ libtorch 2.1.2 would require
   │     │  ├─ cudatoolkit >=11.8,<12 , which conflicts with any installable versions previously reported;
   │     │  └─ libprotobuf >=4.25.1,<4.25.2.0a0 , which conflicts with any installable versions previously reported;
   │     ├─ libtorch 2.1.2 would require
   │     │  └─ pytorch 2.1.2 cuda118_*_300, which can be installed;
   │     ├─ libtorch 2.1.2 would require
   │     │  ├─ cuda-version >=12.0,<13 , which can be installed (as previously explained);
   │     │  ├─ libprotobuf >=4.25.1,<4.25.2.0a0 , which conflicts with any installable versions previously reported;
   │     │  └─ pytorch 2.1.2 cuda120_*_301, which can be installed;
   │     ├─ libtorch 2.1.2 would require
   │     │  ├─ cuda-version >=12.0,<13 , which can be installed (as previously explained);
   │     │  └─ pytorch 2.1.2 cuda120_*_303, which can be installed;
   │     └─ libtorch 2.1.2 would require
   │        └─ pytorch 2.1.2 cuda120_*_300, which can be installed;
   ├─ pytorch 2.1.2 would require
   │  └─ libprotobuf >=4.24.4,<4.24.5.0a0 , which conflicts with any installable versions previously reported;
   ├─ pytorch 2.1.2 would require
   │  └─ python_abi 3.12.* *_cp312, which can be installed (as previously explained);
   └─ pytorch 2.1.2 would require
      ├─ cuda-version >=12.0,<13 , which can be installed (as previously explained);
      └─ libprotobuf >=4.25.1,<4.25.2.0a0 , which conflicts with any installable versions previously reported.

Once a conda-forge pytorch version gets compiled with libprotobuf==4.25.3 (i.e. once https://github.com/conda-forge/pytorch-cpu-feedstock/pull/228 is ready and merged; big thanks to the pytorch and jax conda-forge maintainers), it should be possible to install both jax and pytorch with CUDA enabled, using CUDA 12, with only conda-forge packages.

hawkinsp commented 6 months ago

JAX 0.4.26 relaxed our CUDA version dependencies so the minimum CUDA version for JAX is 12.1. This is a version also supported by PyTorch. Try it out! We're going to try to make sure our supported version range overlaps with at least one PyTorch release.

Note that we dropped support for CUDA 11.
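A quick way to check whether an environment already has a JAX release with the relaxed pins is to compare the installed version against 0.4.26. This is only a sketch: release_tuple and jax_has_relaxed_cuda_pins are hypothetical helper names, and the parser deliberately ignores any suffix after the first three numeric components.

```python
import re
from importlib import metadata


def release_tuple(version):
    """Turn '0.4.26' or '0.4.26.dev20240401' into (0, 4, 26), ignoring suffixes."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.groups())


def jax_has_relaxed_cuda_pins(minimum="0.4.26"):
    """Return True or False for the installed jax, or None if jax is not installed."""
    try:
        installed = metadata.version("jax")
    except metadata.PackageNotFoundError:
        return None
    return release_tuple(installed) >= release_tuple(minimum)


if __name__ == "__main__":
    print(jax_has_relaxed_cuda_pins())
```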

traversaro commented 5 months ago

The only way I was able to solve the environment with both JAX and PyTorch on CUDA12 was to install some packages from the nvidia channel:

FYI, at the moment it is not possible to get both jax and pytorch with cuda 12 only using conda-forge dependencies for this reason (I pinned several dependencies to get a clearer error):

After a bunch of fixes from both jax and pytorch maintainers, it is now (late May 2024) possible to install jax and pytorch from conda-forge on Linux and have them work out of the box with GPU/CUDA support, without needing any other conda channel:

$ conda create -c conda-forge -n jaxpytorch pytorch jax
$ conda activate jaxpytorch
$ python
Python 3.12.3 | packaged by conda-forge | (main, Apr 15 2024, 18:38:13) [GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import jax
>>> torch.cuda.is_available()
True
>>> jax.devices()
[cuda(id=0)]
>>>

If for some reason this command does not install the CUDA-enabled jax, you may still be using the classic conda solver. In that case you can force the installation of CUDA-enabled jax and pytorch with:

conda create -c conda-forge -n jaxpytorch pytorch=*=cuda* jax jaxlib=*=cuda*

However, this is not necessary if you are using a recent conda install that defaults to the conda-libmamba-solver, see https://www.anaconda.com/blog/a-faster-conda-for-a-growing-community .

conda list for reference

~~~
(jaxpytorch) traversaro@IITBMP014LW012:~$ conda list
# packages in environment at /home/traversaro/miniforge3/envs/jaxpytorch:
#
# Name  Version  Build  Channel
_libgcc_mutex  0.1  conda_forge  conda-forge
_openmp_mutex  4.5  2_kmp_llvm  conda-forge
_sysroot_linux-64_curr_repodata_hack  3  h69a702a_14  conda-forge
binutils_impl_linux-64  2.40  ha1999f0_1  conda-forge
binutils_linux-64  2.40  hdade7a5_3  conda-forge
bzip2  1.0.8  hd590300_5  conda-forge
c-ares  1.28.1  hd590300_0  conda-forge
ca-certificates  2024.2.2  hbcca054_0  conda-forge
cuda-cccl_linux-64  12.5.39  ha770c72_0  conda-forge
cuda-crt-dev_linux-64  12.5.40  ha770c72_0  conda-forge
cuda-crt-tools  12.5.40  ha770c72_0  conda-forge
cuda-cudart  12.5.39  he02047a_0  conda-forge
cuda-cudart-dev  12.5.39  he02047a_0  conda-forge
cuda-cudart-dev_linux-64  12.5.39  h85509e4_0  conda-forge
cuda-cudart-static  12.5.39  he02047a_0  conda-forge
cuda-cudart-static_linux-64  12.5.39  h85509e4_0  conda-forge
cuda-cudart_linux-64  12.5.39  h85509e4_0  conda-forge
cuda-cupti  12.5.39  he02047a_0  conda-forge
cuda-driver-dev_linux-64  12.5.39  h85509e4_0  conda-forge
cuda-nvcc  12.5.40  hcdd1206_0  conda-forge
cuda-nvcc-dev_linux-64  12.5.40  ha770c72_0  conda-forge
cuda-nvcc-impl  12.5.40  hd3aeb46_0  conda-forge
cuda-nvcc-tools  12.5.40  hd3aeb46_0  conda-forge
cuda-nvcc_linux-64  12.5.40  h8a487aa_0  conda-forge
cuda-nvrtc  12.5.40  he02047a_0  conda-forge
cuda-nvtx  12.5.39  he02047a_0  conda-forge
cuda-nvvm-dev_linux-64  12.5.40  ha770c72_0  conda-forge
cuda-nvvm-impl  12.5.40  h59595ed_0  conda-forge
cuda-nvvm-tools  12.5.40  h59595ed_0  conda-forge
cuda-version  12.5  hd4f0392_3  conda-forge
cudnn  8.9.7.29  h092f7fd_3  conda-forge
filelock  3.14.0  pyhd8ed1ab_0  conda-forge
fsspec  2024.5.0  pyhff2d567_0  conda-forge
gcc_impl_linux-64  12.3.0  h58ffeeb_7  conda-forge
gcc_linux-64  12.3.0  h6477408_3  conda-forge
gmp  6.3.0  h59595ed_1  conda-forge
gmpy2  2.1.5  py312h1d5cde6_1  conda-forge
gxx_impl_linux-64  12.3.0  h2a574ab_7  conda-forge
gxx_linux-64  12.3.0  h4a1b8e8_3  conda-forge
icu  73.2  h59595ed_0  conda-forge
importlib-metadata  7.1.0  pyha770c72_0  conda-forge
importlib_metadata  7.1.0  hd8ed1ab_0  conda-forge
jax  0.4.27  pyhd8ed1ab_0  conda-forge
jaxlib  0.4.23  cuda120py312h6027bbc_202  conda-forge
jinja2  3.1.4  pyhd8ed1ab_0  conda-forge
kernel-headers_linux-64  3.10.0  h4a8ded7_14  conda-forge
ld_impl_linux-64  2.40  hf3520f5_1  conda-forge
libabseil  20240116.2  cxx17_h59595ed_0  conda-forge
libblas  3.9.0  22_linux64_openblas  conda-forge
libcblas  3.9.0  22_linux64_openblas  conda-forge
libcublas  12.5.2.13  he02047a_0  conda-forge
libcufft  11.2.3.18  he02047a_0  conda-forge
libcurand  10.3.6.39  he02047a_0  conda-forge
libcusolver  11.6.2.40  he02047a_0  conda-forge
libcusparse  12.4.1.24  he02047a_0  conda-forge
libexpat  2.6.2  h59595ed_0  conda-forge
libffi  3.4.2  h7f98852_5  conda-forge
libgcc-devel_linux-64  12.3.0  h0223996_107  conda-forge
libgcc-ng  13.2.0  h77fa898_7  conda-forge
libgfortran-ng  13.2.0  h69a702a_7  conda-forge
libgfortran5  13.2.0  hca663fb_7  conda-forge
libgomp  13.2.0  h77fa898_7  conda-forge
libgrpc  1.62.2  h15f2491_0  conda-forge
libhwloc  2.10.0  default_h5622ce7_1001  conda-forge
libiconv  1.17  hd590300_2  conda-forge
liblapack  3.9.0  22_linux64_openblas  conda-forge
libmagma  2.7.2  h173bb3b_2  conda-forge
libmagma_sparse  2.7.2  h173bb3b_3  conda-forge
libnsl  2.0.1  hd590300_0  conda-forge
libnvjitlink  12.5.40  he02047a_0  conda-forge
libopenblas  0.3.27  pthreads_h413a1c8_0  conda-forge
libprotobuf  4.25.3  h08a7969_0  conda-forge
libre2-11  2023.09.01  h5a48ba9_2  conda-forge
libsanitizer  12.3.0  hb8811af_7  conda-forge
libsqlite  3.45.3  h2797004_0  conda-forge
libstdcxx-devel_linux-64  12.3.0  h0223996_107  conda-forge
libstdcxx-ng  13.2.0  hc0a3c3a_7  conda-forge
libtorch  2.3.0  cuda120_h2b0da52_301  conda-forge
libuuid  2.38.1  h0b41bf4_0  conda-forge
libuv  1.48.0  hd590300_0  conda-forge
libxcrypt  4.4.36  hd590300_1  conda-forge
libxml2  2.12.7  hc051c1a_0  conda-forge
libzlib  1.2.13  hd590300_5  conda-forge
llvm-openmp  18.1.6  ha31de31_0  conda-forge
markupsafe  2.1.5  py312h98912ed_0  conda-forge
mkl  2023.2.0  h84fe81f_50496  conda-forge
ml_dtypes  0.4.0  py312h1d6d2e6_1  conda-forge
mpc  1.3.1  hfe3b2da_0  conda-forge
mpfr  4.2.1  h9458935_1  conda-forge
mpmath  1.3.0  pyhd8ed1ab_0  conda-forge
nccl  2.21.5.1  h3a97aeb_0  conda-forge
ncurses  6.5  h59595ed_0  conda-forge
networkx  3.3  pyhd8ed1ab_1  conda-forge
numpy  1.26.4  py312heda63a1_0  conda-forge
openssl  3.3.0  h4ab18f5_3  conda-forge
opt-einsum  3.3.0  hd8ed1ab_2  conda-forge
opt_einsum  3.3.0  pyhc1e730c_2  conda-forge
pip  24.0  pyhd8ed1ab_0  conda-forge
python  3.12.3  hab00c5b_0_cpython  conda-forge
python_abi  3.12  4_cp312  conda-forge
pytorch  2.3.0  cuda120_py312h26b3cf7_301  conda-forge
re2  2023.09.01  h7f4b329_2  conda-forge
readline  8.2  h8228510_1  conda-forge
scipy  1.13.1  py312hc2bc53b_0  conda-forge
setuptools  70.0.0  pyhd8ed1ab_0  conda-forge
sleef  3.5.1  h9b69904_2  conda-forge
sympy  1.12  pypyh9d50eac_103  conda-forge
sysroot_linux-64  2.17  h4a8ded7_14  conda-forge
tbb  2021.12.0  h297d8ca_1  conda-forge
tk  8.6.13  noxft_h4845f30_101  conda-forge
typing_extensions  4.11.0  pyha770c72_0  conda-forge
tzdata  2024a  h0c530f3_0  conda-forge
wheel  0.43.0  pyhd8ed1ab_1  conda-forge
xz  5.2.6  h166bdaf_0  conda-forge
zipp  3.17.0  pyhd8ed1ab_0  conda-forge
zstd  1.5.6  ha6fb4c9_0  conda-forge
~~~

varadVaidya commented 3 months ago

Can someone please point out the correct version necessary to get pytorch and jax both with GPU support on CUDA 12 as of July 2024? I would prefer it to be a standard venv rather than a conda env, but either is fine.

traversaro commented 3 months ago

@varadVaidya I only follow this issue by chance; in general you may have more success finding help through the official JAX help channels (see https://jax.readthedocs.io/en/latest/beginner_guide.html#finding-help) rather than by posting in closed issues.

More on topic, I have no idea about pip/venv with CUDA, but for conda the procedure posted in https://github.com/google/jax/issues/18032#issuecomment-2132399059 is working fine for me. (When I originally posted that message I forgot to add -c conda-forge, which ensures it also works on Anaconda or Miniconda installations of conda that use defaults instead of conda-forge; I just fixed that to avoid confusion.)

@eliseoe @bebark @shaikalthaf4 By chance I just noticed that you added a 👎🏽 reaction to my previous comment; any reason for doing so? Just FYI, authors do not get notifications for post reactions (at least by default).

peterch405 commented 3 months ago

@traversaro I found that running your command with conda will install:

jaxlib conda-forge/linux-64::jaxlib-0.4.27-cpu_py312h17e8b90_0

whereas with mamba the correct version is installed:

mamba create -c conda-forge -n jaxpytorch pytorch jax

jaxlib 0.4.27 cuda120py312h4008524_200 conda-forge/linux-64

Perhaps this is why you got 3 thumbs down.

traversaro commented 3 months ago

> @traversaro I found that running your command with conda will install:
>
> jaxlib conda-forge/linux-64::jaxlib-0.4.27-cpu_py312h17e8b90_0
>
> whereas with mamba the correct version is installed:
>
> mamba create -c conda-forge -n jaxpytorch pytorch jax
>
> jaxlib 0.4.27 cuda120py312h4008524_200 conda-forge/linux-64
>
> Perhaps this is why you got 3 thumbs down.

@peterch405

Interestingly, on my system with:

root@DESKTOP-T0NQNLN:~# conda info

     active environment : None
            shell level : 0
       user config file : /root/.condarc
 populated config files : /root/miniforge3/.condarc
                          /root/.condarc
          conda version : 24.3.0
    conda-build version : not installed
         python version : 3.10.14.final.0
                 solver : libmamba (default)
       virtual packages : __archspec=1=skylake
                          __conda=24.3.0=0
                          __cuda=12.0=0
                          __glibc=2.39=0
                          __linux=5.15.153.1=0
                          __unix=0=0
       base environment : /root/miniforge3  (writable)
      conda av data dir : /root/miniforge3/etc/conda
  conda av metadata url : None
           channel URLs : https://conda.anaconda.org/conda-forge/linux-64
                          https://conda.anaconda.org/conda-forge/noarch
          package cache : /root/miniforge3/pkgs
                          /root/.conda/pkgs
       envs directories : /root/miniforge3/envs
                          /root/.conda/envs
               platform : linux-64
             user-agent : conda/24.3.0 requests/2.31.0 CPython/3.10.14 Linux/5.15.153.1-microsoft-standard-WSL2 ubuntu/24.04 glibc/2.39 solver/libmamba conda-libmamba-solver/24.1.0 libmambapy/1.5.8
                UID:GID : 0:0
             netrc file : None
           offline mode : False

the command

conda create -c conda-forge -n jaxpytorch pytorch jax

installs the cuda jax, but indeed:

conda create --solver=classic -c conda-forge -n jaxpytorch pytorch jax

installs cpu jax. Perhaps you are using an old conda version that is using the classic solver by default? (You can see this if you report the conda info output, see https://www.anaconda.com/blog/a-faster-conda-for-a-growing-community).

However, even with the classic solver, forcing it to install the CUDA builds of jaxlib and pytorch works as expected (although the classic solver is much slower):

conda create --solver=classic -c conda-forge -n jaxpytorch pytorch=*=cuda* jax jaxlib=*=cuda*

I edited the original comment accordingly.
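The CPU and CUDA results discussed above can be told apart from the build string alone. Here is a minimal sketch, assuming the conda-forge builds keep the `cpu` / `cuda` prefixes seen in this thread; `build_variant` is a hypothetical helper, not part of conda or JAX.

```python
def build_variant(build: str) -> str:
    """Classify a conda build string as 'cuda', 'cpu', or 'unknown'.

    Assumption: the build string starts with 'cuda' or 'cpu', as in the
    jaxlib and pytorch builds quoted in this thread.
    """
    lowered = build.lower()
    if lowered.startswith("cuda"):
        return "cuda"
    if lowered.startswith("cpu"):
        return "cpu"
    return "unknown"


# Build strings copied from the comments above.
for build in ("cpu_py312h17e8b90_0", "cuda120py312h4008524_200"):
    print(build, "->", build_variant(build))
```

If the variant comes back as `cpu` for jaxlib, the solver picked the CPU build and the `jaxlib=*=cuda*` spec from the command above is the way to force the CUDA one.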

peterch405 commented 3 months ago

You are right, I'm using the classic solver:

     active environment : base
    active env location : /home/chovanec/miniconda3
            shell level : 1
       user config file : /home/chovanec/.condarc
 populated config files : /home/chovanec/.condarc
          conda version : 23.1.0
    conda-build version : not installed
         python version : 3.10.8.final.0
       virtual packages : __archspec=1=x86_64
                          __cuda=12.3=0
                          __glibc=2.31=0
                          __linux=5.15.153.1=0
                          __unix=0=0
       base environment : /home/chovanec/miniconda3  (writable)
      conda av data dir : /home/chovanec/miniconda3/etc/conda
  conda av metadata url : None
           channel URLs : https://conda.anaconda.org/bioconda/linux-64
                          https://conda.anaconda.org/bioconda/noarch
                          https://conda.anaconda.org/conda-forge/linux-64
                          https://conda.anaconda.org/conda-forge/noarch
                          https://repo.anaconda.com/pkgs/main/linux-64
                          https://repo.anaconda.com/pkgs/main/noarch
                          https://repo.anaconda.com/pkgs/r/linux-64
                          https://repo.anaconda.com/pkgs/r/noarch
          package cache : /home/chovanec/miniconda3/pkgs
                          /home/chovanec/.conda/pkgs
       envs directories : /home/chovanec/miniconda3/envs
                          /home/chovanec/.conda/envs
               platform : linux-64
             user-agent : conda/23.1.0 requests/2.28.1 CPython/3.10.8 Linux/5.15.153.1-microsoft-standard-WSL2 ubuntu/20.04.6 glibc/2.31
                UID:GID : 1000:1000
             netrc file : None
           offline mode : False
shellyzhang2019 commented 3 months ago

The only way I was able to solve the environment with both JAX and PyTorch on CUDA12 was to install some packages from the nvidia channel:

mamba create -n jaxTorch jaxlib pytorch cuda-nvcc -c conda-forge -c nvidia -c pytorch
>>> import torch
>>> import jax
>>> torch.cuda.is_available()
True
>>> jax.devices()
[cuda(id=0)]
>>> import jaxlib.cuda._versions
>>> jaxlib.cuda._versions.cudnn_get_version()
8902
>>> torch._C._cudnn.getCompileVersion()
(8, 9, 2)

conda list fyi @traversaro

Thanks for the solution. However, I have found a possible bug: with jax and jaxlib both at version 0.4.30, jax.numpy cannot initialize an array whose size is bigger than (2, 52, 10). I had to downgrade to 0.4.23, and then it works just fine. So, to be safe, the command could be:

conda create -n _env_name_ jaxlib=0.4.23 pytorch cuda-nvcc python=3.11 -c conda-forge -c nvidia -c pytorch

Python 3.12 is too new for some commonly used packages.
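The interactive check quoted above (`torch.cuda.is_available()` and `jax.devices()`) can be wrapped in a small script to run after creating an environment. This is only a sketch: `gpu_report` is a hypothetical helper name, and a missing framework is reported as `None` rather than treated as an error.

```python
import importlib


def gpu_report() -> dict:
    """Report what PyTorch and JAX can see, using the same calls as the thread."""
    report = {"torch_cuda": None, "jax_devices": None}
    try:
        torch = importlib.import_module("torch")
        report["torch_cuda"] = bool(torch.cuda.is_available())
    except ImportError:
        pass  # torch is not installed in this environment
    try:
        jax = importlib.import_module("jax")
        # On a working GPU install this lists CUDA devices; on a CPU
        # fallback it lists a CPU device instead.
        report["jax_devices"] = [str(device) for device in jax.devices()]
    except (ImportError, RuntimeError):
        pass  # jax is not installed, or no backend could be initialized
    return report


if __name__ == "__main__":
    print(gpu_report())
```

A CPU-only device in `jax_devices` while `torch_cuda` is `True` is the mismatch this issue is about.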

traversaro commented 3 months ago

Just out of curiosity, are you actually getting any packages from the nvidia or pytorch channels? If the conda-forge channel is used with strict priority, all the packages you get should come from conda-forge, so I guess you could drop -c nvidia -c pytorch from your command. You can check this by running conda list and looking at which channel each package was installed from.
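The channel check suggested here can be scripted by grouping the rows of `conda list` by their last column. A minimal sketch follows; `packages_by_channel` is a hypothetical helper, the environment path in the sample is made up, and the three rows are copied from a listing earlier in this thread. As far as I know, pip-installed packages show up with the channel `pypi`, so the same grouping should also reveal packages that came from pip.

```python
from collections import defaultdict


def packages_by_channel(conda_list_text: str) -> dict[str, list[str]]:
    """Group package names by the channel column of `conda list` output."""
    groups = defaultdict(list)
    for line in conda_list_text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue  # skip blank lines and the comment header
        fields = line.split()
        # Columns are: name, version, build, channel (channel may be absent).
        channel = fields[3] if len(fields) >= 4 else "(none)"
        groups[channel].append(fields[0])
    return dict(groups)


SAMPLE = """\
# packages in environment at /home/user/envs/jaxpytorch:
#
# Name                    Version                   Build  Channel
jaxlib                    0.4.23          cuda120py312h6027bbc_202    conda-forge
pytorch                   2.3.0           cuda120_py312h26b3cf7_301    conda-forge
cudnn                     8.9.7.29             h092f7fd_3    conda-forge
"""

print(packages_by_channel(SAMPLE))
```

In practice the text would come from running `conda list` inside the environment and passing its output to the function.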

shellyzhang2019 commented 3 months ago

> Just out of curiosity, are you actually getting any packages from the nvidia or pytorch channels? If the conda-forge channel is used with strict priority, all the packages you get should come from conda-forge, so I guess you could drop -c nvidia -c pytorch from your command. You can check this by running conda list and looking at which channel each package was installed from.

I'm not sure; maybe I can test it later. Thanks for pointing it out.

shellyzhang2019 commented 3 months ago

> Just out of curiosity, are you actually getting any packages from the nvidia or pytorch channels? If the conda-forge channel is used with strict priority, all the packages you get should come from conda-forge, so I guess you could drop -c nvidia -c pytorch from your command. You can check this by running conda list and looking at which channel each package was installed from.

Sorry for the late reply, here is the output: (image)

Since jax and the jax CUDA libraries were manually reinstalled from PyPI, I guess that yes, the packages are installed preferentially from conda-forge :)

traversaro commented 3 months ago

Not sure how you can end up with jax/jaxlib installed via PyPI if you just created the environment with conda create -n _env_name_ jaxlib=0.4.23 pytorch cuda-nvcc python=3.11 -c conda-forge -c nvidia -c pytorch, but as a general comment, if you are installing something with pip it is a good idea not to also install it via conda, to avoid conflicts.

shellyzhang2019 commented 3 months ago

> Not sure how you can end up with jax/jaxlib installed via PyPI if you just created the environment with conda create -n _env_name_ jaxlib=0.4.23 pytorch cuda-nvcc python=3.11 -c conda-forge -c nvidia -c pytorch, but as a general comment, if you are installing something with pip it is a good idea not to also install it via conda, to avoid conflicts.

In my case, the conflict comes from torch and jaxlib sticking to different cudnn versions. Previously I had not turned to conda-forge to install a cudatoolkit compatible with both torch and jaxlib. By the way, I used the pip command from the official JAX documentation.

traversaro commented 3 months ago

Ok, but in that case it is probably a good idea not to install jax and jaxlib from conda, and only install them from pip.

shellyzhang2019 commented 3 months ago

> Ok, but in that case it is probably a good idea not to install jax and jaxlib from conda, and only install them from pip.

I think the only reason for including jax and jaxlib in the conda command is to make sure conda-forge can find and install a compatible cudnn version. I did not test this, so to be safe I recommend reinstalling jax and jaxlib from pip afterwards, annoying as that is.

traversaro commented 3 months ago

> > Ok, but in that case it is probably a good idea not to install jax and jaxlib from conda, and only install them from pip.
>
> I think the only reason for including jax and jaxlib in the conda command is to make sure conda-forge can find and install a compatible cudnn version. I did not test this, so to be safe I recommend reinstalling jax and jaxlib from pip afterwards, annoying as that is.

But conda has no idea which version of cudnn the pip-installed jaxlib requires. If you want to install cudnn (even a specific version) with conda, just install cudnn. To avoid problems, it is typically best not to install jax or jaxlib via conda if you are installing them via pip.

shellyzhang2019 commented 3 months ago

> > > Ok, but in that case it is probably a good idea not to install jax and jaxlib from conda, and only install them from pip.
> >
> > I think the only reason for including jax and jaxlib in the conda command is to make sure conda-forge can find and install a compatible cudnn version. I did not test this, so to be safe I recommend reinstalling jax and jaxlib from pip afterwards, annoying as that is.
>
> But conda has no idea which version of cudnn the pip-installed jaxlib requires. If you want to install cudnn (even a specific version) with conda, just install cudnn. To avoid problems, it is typically best not to install jax or jaxlib via conda if you are installing them via pip.

You are right. I happened to use pip install, and it just found a cudnn version that meets the requirement, lol.

lucascolley commented 2 months ago

@traversaro apologies for having to revive this issue, but your solution does not work for me:

$ conda create -c conda-forge -n jaxpytorch pytorch jax
$ conda activate jaxpytorch
$ python
Python 3.12.4 | packaged by conda-forge | (main, Jun 17 2024, 10:23:07) [GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import jax
>>> torch.cuda.is_available()
True
>>> jax.devices()
CUDA backend failed to initialize: Unable to use CUDA because of the following issues with CUDA components:
Outdated cuDNN installation found.
Version JAX was built against: 8907
Minimum supported: 9100
Installed version: 8907
The local installation version must be no lower than 9100..(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[CpuDevice(id=0)]

The error message suggests I need cudnn>=9.1.0, however the most recent version on conda-forge appears to be 8.9.7.29.
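The reading of `9100` as 9.1.0 and `8907` as 8.9.7 follows from how these integers are packed. A minimal sketch, assuming the four-digit form `major * 1000 + minor * 100 + patch` that matches the values in this thread (earlier, `8902` appears next to `(8, 9, 2)`); `decode_cudnn_version` is a hypothetical helper, and as far as I know cuDNN 9 itself reports a longer five-digit code that this sketch does not handle.

```python
def decode_cudnn_version(code: int) -> tuple[int, int, int]:
    """Split a four-digit cuDNN version code into (major, minor, patch).

    Assumption: code == major * 1000 + minor * 100 + patch, which is
    consistent with the numbers printed in this thread.
    """
    major, rest = divmod(code, 1000)
    minor, patch = divmod(rest, 100)
    return major, minor, patch


# Values taken from the error message and the earlier REPL session.
for code in (8907, 9100, 8902):
    print(code, "->", decode_cudnn_version(code))
```

Decoded this way, the installed cuDNN is 8.9.7 while the minimum accepted by this JAX build is 9.1.0, which is why the CUDA backend refuses to initialize.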

System info:

- Operating System: Ubuntu 22.04.4 LTS
- GPU: NVIDIA GeForce GTX 1060 6GB
- Graphics Driver: NVIDIA driver metapackage from nvidia-driver-535

mamba list below the fold:

``` # packages in environment at /home/lucas/mambaforge/envs/jaxpytorch: # # Name Version Build Channel _libgcc_mutex 0.1 conda_forge conda-forge _openmp_mutex 4.5 2_kmp_llvm conda-forge _sysroot_linux-64_curr_repodata_hack 3 h69a702a_16 conda-forge binutils_impl_linux-64 2.40 ha1999f0_7 conda-forge binutils_linux-64 2.40 hb3c18ed_0 conda-forge bzip2 1.0.8 h4bc722e_7 conda-forge c-ares 1.32.3 h4bc722e_0 conda-forge ca-certificates 2024.7.4 hbcca054_0 conda-forge cuda-cccl_linux-64 12.5.39 ha770c72_0 conda-forge cuda-crt-dev_linux-64 12.5.82 ha770c72_0 conda-forge cuda-crt-tools 12.5.82 ha770c72_0 conda-forge cuda-cudart 12.5.82 he02047a_0 conda-forge cuda-cudart-dev 12.5.82 he02047a_0 conda-forge cuda-cudart-dev_linux-64 12.5.82 h85509e4_0 conda-forge cuda-cudart-static 12.5.82 he02047a_0 conda-forge cuda-cudart-static_linux-64 12.5.82 h85509e4_0 conda-forge cuda-cudart_linux-64 12.5.82 h85509e4_0 conda-forge cuda-cupti 12.5.82 he02047a_0 conda-forge cuda-driver-dev_linux-64 12.5.82 h85509e4_0 conda-forge cuda-nvcc 12.5.82 hcdd1206_0 conda-forge cuda-nvcc-dev_linux-64 12.5.82 ha770c72_0 conda-forge cuda-nvcc-impl 12.5.82 hd3aeb46_0 conda-forge cuda-nvcc-tools 12.5.82 hd3aeb46_0 conda-forge cuda-nvcc_linux-64 12.5.82 h8a487aa_0 conda-forge cuda-nvrtc 12.5.82 he02047a_0 conda-forge cuda-nvtx 12.5.82 he02047a_0 conda-forge cuda-nvvm-dev_linux-64 12.5.82 ha770c72_0 conda-forge cuda-nvvm-impl 12.5.82 h59595ed_0 conda-forge cuda-nvvm-tools 12.5.82 h59595ed_0 conda-forge cuda-version 12.5 hd4f0392_3 conda-forge cudnn 8.9.7.29 h092f7fd_3 conda-forge filelock 3.15.4 pyhd8ed1ab_0 conda-forge fsspec 2024.6.1 pyhff2d567_0 conda-forge gcc_impl_linux-64 13.3.0 hfea6d02_0 conda-forge gcc_linux-64 13.3.0 hc28eda2_0 conda-forge gmp 6.3.0 hac33072_2 conda-forge gmpy2 2.1.5 py312h1d5cde6_1 conda-forge gxx_impl_linux-64 13.3.0 hffce095_0 conda-forge gxx_linux-64 13.3.0 h6834431_0 conda-forge icu 75.1 he02047a_0 conda-forge importlib-metadata 8.2.0 pyha770c72_0 conda-forge 
importlib_metadata 8.2.0 hd8ed1ab_0 conda-forge jax 0.4.31 pyhd8ed1ab_0 conda-forge jaxlib 0.4.30 cuda120py312h4008524_200 conda-forge jinja2 3.1.4 pyhd8ed1ab_0 conda-forge kernel-headers_linux-64 3.10.0 h4a8ded7_16 conda-forge ld_impl_linux-64 2.40 hf3520f5_7 conda-forge libabseil 20240116.2 cxx17_he02047a_1 conda-forge libblas 3.9.0 23_linux64_openblas conda-forge libcblas 3.9.0 23_linux64_openblas conda-forge libcublas 12.5.3.2 he02047a_0 conda-forge libcufft 11.2.3.61 he02047a_0 conda-forge libcurand 10.3.6.82 he02047a_0 conda-forge libcusolver 11.6.3.83 he02047a_0 conda-forge libcusparse 12.5.1.3 he02047a_0 conda-forge libexpat 2.6.2 h59595ed_0 conda-forge libffi 3.4.2 h7f98852_5 conda-forge libgcc-devel_linux-64 13.3.0 h84ea5a7_100 conda-forge libgcc-ng 14.1.0 h77fa898_0 conda-forge libgfortran-ng 14.1.0 h69a702a_0 conda-forge libgfortran5 14.1.0 hc5f4f2c_0 conda-forge libgomp 14.1.0 h77fa898_0 conda-forge libgrpc 1.62.2 h15f2491_0 conda-forge libhwloc 2.11.1 default_hecaa2ac_1000 conda-forge libiconv 1.17 hd590300_2 conda-forge liblapack 3.9.0 23_linux64_openblas conda-forge libmagma 2.7.2 h173bb3b_2 conda-forge libmagma_sparse 2.7.2 h173bb3b_3 conda-forge libnsl 2.0.1 hd590300_0 conda-forge libnvjitlink 12.5.82 he02047a_0 conda-forge libopenblas 0.3.27 pthreads_hac2b453_1 conda-forge libprotobuf 4.25.3 h08a7969_0 conda-forge libre2-11 2023.09.01 h5a48ba9_2 conda-forge libsanitizer 13.3.0 heb74ff8_0 conda-forge libsqlite 3.46.0 hde9e2c9_0 conda-forge libstdcxx-devel_linux-64 13.3.0 h84ea5a7_100 conda-forge libstdcxx-ng 14.1.0 hc0a3c3a_0 conda-forge libtorch 2.3.1 cuda120_h2b0da52_300 conda-forge libuuid 2.38.1 h0b41bf4_0 conda-forge libuv 1.48.0 hd590300_0 conda-forge libxcrypt 4.4.36 hd590300_1 conda-forge libxml2 2.12.7 he7c6b58_4 conda-forge libzlib 1.3.1 h4ab18f5_1 conda-forge llvm-openmp 18.1.8 hf5423f3_0 conda-forge markupsafe 2.1.5 py312h98912ed_0 conda-forge mkl 2023.2.0 h84fe81f_50496 conda-forge ml_dtypes 0.4.0 py312h1d6d2e6_1 conda-forge mpc 1.3.1 
hfe3b2da_0 conda-forge mpfr 4.2.1 h38ae2d0_2 conda-forge mpmath 1.3.0 pyhd8ed1ab_0 conda-forge nccl 2.22.3.1 hbc370b7_1 conda-forge ncurses 6.5 h59595ed_0 conda-forge networkx 3.3 pyhd8ed1ab_1 conda-forge numpy 2.0.1 py312h1103770_0 conda-forge openssl 3.3.1 h4bc722e_2 conda-forge opt-einsum 3.3.0 hd8ed1ab_2 conda-forge opt_einsum 3.3.0 pyhc1e730c_2 conda-forge pip 24.0 pyhd8ed1ab_0 conda-forge python 3.12.4 h194c7f8_0_cpython conda-forge python_abi 3.12 4_cp312 conda-forge pytorch 2.3.1 cuda120_py312h26b3cf7_300 conda-forge re2 2023.09.01 h7f4b329_2 conda-forge readline 8.2 h8228510_1 conda-forge scipy 1.14.0 py312hc2bc53b_1 conda-forge setuptools 71.0.4 pyhd8ed1ab_0 conda-forge sleef 3.6.1 h3400bea_1 conda-forge sympy 1.13.0 pypyh2585a3b_103 conda-forge sysroot_linux-64 2.17 h4a8ded7_16 conda-forge tbb 2021.12.0 h434a139_3 conda-forge tk 8.6.13 noxft_h4845f30_101 conda-forge typing_extensions 4.12.2 pyha770c72_0 conda-forge tzdata 2024a h0c530f3_0 conda-forge wheel 0.43.0 pyhd8ed1ab_1 conda-forge xz 5.2.6 h166bdaf_0 conda-forge zipp 3.19.2 pyhd8ed1ab_0 conda-forge zstd 1.5.6 ha6fb4c9_0 conda-forge ```

UPDATE: jax=0.4.28 ~appears to~ does work, so this looks like a bug introduced recently in JAX. cc @hawkinsp

traversaro commented 2 months ago

@lucascolley thanks for reporting the issue, can you please open an issue in https://github.com/conda-forge/jaxlib-feedstock and tag me there? Thanks!

traversaro commented 2 months ago

> @lucascolley thanks for reporting the issue, can you please open an issue in https://github.com/conda-forge/jaxlib-feedstock and tag me there? Thanks!

Thanks @lucascolley, indeed it seems to be a regression in the conda-forge jax package 0.4.31. I opened https://github.com/conda-forge/jaxlib-feedstock/issues/277 to track the problem.