Closed ywsslr closed 6 months ago
That's correct. The current releases of PyTorch and JAX have incompatible CUDA version dependencies.
I reported this issue to the PyTorch developers a while back, but there has been no interest in relaxing their CUDA version dependencies.
My recommendations:
Does that resolve your problem?
Hope that helps!
This is quite annoying (and inconvenient) now that people have written torch2jax functionality which allows GPU-accelerated interaction, e.g.:
https://github.com/samuela/torch2jax and https://github.com/rdyro/torch2jax
Hi @ywsslr, I've been experimenting with using Torch and JAX simultaneously for a while. I'm currently working in a Docker container in which they both work on GPU.
JAX was installed according to the official documentation as:
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
I leave here the Conda YAML of the environment, there will probably be some extra packages, but I hope this can help:
Thank you for all your help. For some reason I can't try it right now, but I'll do so soon and reply.
OK people, this has been a one-day nightmare, but I finally got this to work on an H100 machine with CUDA 12.2, without sudo.
First, install CUDA 12.2 (this was already there for me).
Then install cuDNN 8.9 through the official website, using the tar option: https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#installlinux-tar
Then follow what this guy did to build magma: https://github.com/huggingface/autotrain-advanced/issues/281#issuecomment-1740762360
Then install PyTorch from source as that post says, and voilà!
No promises, but informally we're going to try to keep at least one JAX release whose CUDA version is also one that PyTorch releases against. Right now, that's the CUDA 11.8 release of JAX.
It's not a guarantee, though; it might happen that for some JAX and Pytorch versions there's no intersecting CUDA version.
I hit a similar issue when installing pytorch and jax into the same conda environment: when torch is loaded first, jax.devices() will list only CPU devices.
A short summary of the diagnosis: it turns out that torch is built against cuDNN version 8.7 while jaxlib is built against cuDNN version 8.8, leading to an exception when executing jax._src.xla_bridge._check_cuda_versions().
Here follows a reproducer:
mamba create -n test-pytorch-jax pytorch pytorch-cuda=11.8 jaxlib=*=*cuda118* jax -c pytorch -c nvidia --no-channel-priority -y
mamba activate test-pytorch-jax
(note: using strict channel priority would lead to a mamba solver problem).
Import torch before checking jax.devices:
>>> import torch
>>> import jax
>>> jax.devices()
CUDA backend failed to initialize: Found cuDNN version 8700, but JAX was built against version 8800, which is newer. The copy of cuDNN that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[CpuDevice(id=0)]
Import torch after checking jax.devices:
>>> import jax
>>> jax.devices()
[cuda(id=0), cuda(id=1)]
>>> import torch
>>> jax.__version__
'0.4.23'
>>> torch.__version__
'2.1.2'
>>> from torch._C import _cudnn
>>> _cudnn.getCompileVersion()
(8, 7, 0)
Notice that the result of jaxlib.cuda._versions.cudnn_get_version() depends on whether torch was imported before or after the first call to jaxlib.cuda._versions.cudnn_get_version():
>>> import jaxlib.cuda._versions
>>> jaxlib.cuda._versions.cudnn_get_version()
8902
>>> import torch
>>> jaxlib.cuda._versions.cudnn_get_version()
8902
vs
>>> import torch
>>> import jaxlib.cuda._versions
>>> jaxlib.cuda._versions.cudnn_get_version()
8700
That qualifies as an incompatible linkage issue: since libcudnn is dynamically loaded, the result of cudnnGetVersion ought to give the version of the loaded library, not the version of the library that a piece of software was built against. The behavior above suggests that torch was linked with libcudnn statically.
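Given this diagnosis, one pragmatic workaround is simply to control import order: let JAX initialize its CUDA backend before torch is imported, matching the ordering that works in the reproducer above. A minimal sketch (the helper name is my own invention; it degrades gracefully when either package is missing):

```python
import importlib.util

def import_jax_then_torch():
    # Hypothetical helper reflecting the working order shown above:
    # initialize JAX's backend before torch loads its own copy of cuDNN.
    if importlib.util.find_spec("jax") is not None:
        import jax
        jax.devices()  # forces backend init while jax's cuDNN is the one loaded
    if importlib.util.find_spec("torch") is not None:
        import torch  # torch's (possibly statically linked) cuDNN comes second

import_jax_then_torch()
```

Calling something like this at the very top of an entry point reproduces the "import torch after checking jax.devices" ordering without relying on every module getting its imports in the right order.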
A possible resolution: note that cuDNN minor releases are backward compatible with applications built against the same or an earlier minor release. Hence, as long as jaxlib and torch are built against libcudnn with the same major version (8), the jax version check ought to ignore cuDNN minor versions. Here is a patch:
diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py
index 7977f6329..17c14bc5a 100644
--- a/jax/_src/xla_bridge.py
+++ b/jax/_src/xla_bridge.py
@@ -263,7 +263,7 @@ def _check_cuda_versions():
cuda_versions.cudnn_build_version,
# NVIDIA promise both backwards and forwards compatibility for cuDNN patch
# versions: https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#api-compat
- scale_for_comparison=100,
+ scale_for_comparison=1000,
)
_version_check("cuFFT", cuda_versions.cufft_get_version,
cuda_versions.cufft_build_version,
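To see what the one-character patch changes, here is a pure-Python sketch of the comparison (my own simplification of `_version_check`, using cuDNN's `major*1000 + minor*100 + patch` integer encoding, e.g. 8902 = 8.9.2):

```python
def version_ok(local: int, build: int, scale_for_comparison: int) -> bool:
    # Simplified sketch of the check in jax._src.xla_bridge._version_check:
    # integer division by the scale discards the components below it,
    # so only the remaining leading components are compared.
    return local // scale_for_comparison >= build // scale_for_comparison

# scale 100 still compares minor versions: torch's cuDNN 8.7.0 vs jaxlib's 8.8.0 fails
assert not version_ok(8700, 8800, scale_for_comparison=100)
# scale 1000 compares major versions only, so any 8.x passes against an 8.8 build
assert version_ok(8700, 8800, scale_for_comparison=1000)
```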
The latest version pair I could find that were compatible with each other were jax[cuda11_pip]==0.4.10 and torch==2.2.1+cu118. The main conflict in later versions is jax's cudnn requirement, which wants >8.8, while torch wants ==8.7.
One way to check this would be:
cat > requirements.in <<EOF
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
--extra-index-url=https://download.pytorch.org/whl
jax[cuda11_pip]
torch==2.2.1+cu118
EOF
pip-compile
# Check the contents of requirements.txt.
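On a pip-based install, another quick way to spot a cuDNN pin conflict is to list the `nvidia-*` wheels actually present in the environment. A stdlib-only sketch (wheel names like `nvidia-cudnn-cu12` are what the JAX CUDA extras happen to pull in; treat the exact names as an assumption):

```python
from importlib import metadata

def nvidia_pins() -> dict:
    # Collect the versions of all installed nvidia-* wheels so that a cuDNN
    # pin pulled in by jax can be compared against what torch depends on.
    pins = {}
    for dist in metadata.distributions():
        name = (dist.metadata["Name"] or "").lower()
        if name.startswith("nvidia-"):
            pins[name] = dist.version
    return pins

print(nvidia_pins())  # empty dict if no NVIDIA wheels are installed
```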
A workaround that works better for us is to use CUDA 11 with Jax, but CUDA 12 with Torch. So basically jax[cuda11_pip] and torch in our requirements file works for us.
How did you get this to work? I'm using conda, but after installing pytorch-cuda=12.1 I get the following error from JAX:
E RuntimeError: Unable to initialize backend 'cuda': Unable to load CUDA. Is it installed? (set JAX_PLATFORMS='' to automatically choose an available backend)
We did not have to do anything special. Just installed the two packages in a clean env, and both worked.
The only way I was able to solve the environment with both JAX and PyTorch on CUDA12 was to install some packages from the nvidia channel:
mamba create -n jaxTorch jaxlib pytorch cuda-nvcc -c conda-forge -c nvidia -c pytorch
>>> import torch
>>> import jax
>>> torch.cuda.is_available()
True
>>> jax.devices()
[cuda(id=0)]
>>> import jaxlib.cuda._versions
>>> jaxlib.cuda._versions.cudnn_get_version()
8902
>>> torch._C._cudnn.getCompileVersion()
(8, 9, 2)
fyi @traversaro
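For quickly re-checking an environment like the one above, a small guarded helper can run both probes without crashing when one of the packages is missing (a sketch; the function name is mine):

```python
def check_gpu_stack() -> dict:
    # Probe torch and jax independently so a broken or missing package
    # does not hide the status of the other one.
    report = {}
    try:
        import torch
        report["torch_cuda"] = torch.cuda.is_available()
    except ImportError:
        report["torch_cuda"] = None
    try:
        import jax
        report["jax_platforms"] = sorted({d.platform for d in jax.devices()})
    except ImportError:
        report["jax_platforms"] = None
    return report

print(check_gpu_stack())
```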
FYI, at the moment it is not possible to get both jax and pytorch with cuda 12 only using conda-forge dependencies for this reason (I pinned several dependencies to get a clearer error):
traversaro@IITBMP014LW012:~$ mamba create -n jaxtorchcuda pytorch==2.1.2=*cuda* jaxlib==0.4.23=*cuda* jax cuda-version=12.* python==3.11.* cudatoolkit==12.*
Looking for: ['pytorch==2.1.2[build=*cuda*]', 'jaxlib==0.4.23[build=*cuda*]', 'jax', 'cuda-version=12', 'python=3.11', 'cudatoolkit=12']
conda-forge/linux-64 Using cache
conda-forge/noarch Using cache
Could not solve for environment specs
The following packages are incompatible
├─ cuda-version 12** is installable with the potential options
│ ├─ cuda-version [12.0|12.0.0] would require
│ │ └─ cudatoolkit 12.0|12.0.* , which can be installed;
│ ├─ cuda-version 12.1 would require
│ │ └─ cudatoolkit 12.1|12.1.* , which can be installed;
│ ├─ cuda-version 12.2 would require
│ │ └─ cudatoolkit 12.2|12.2.* , which can be installed;
│ ├─ cuda-version 12.3 would require
│ │ └─ cudatoolkit 12.3|12.3.* , which can be installed;
│ └─ cuda-version 12.4 would require
│ └─ cudatoolkit 12.4|12.4.* , which can be installed;
├─ cudatoolkit 12** does not exist (perhaps a typo or a missing channel);
├─ jaxlib 0.4.23 *cuda* is installable with the potential options
│ ├─ jaxlib 0.4.23 would require
│ │ └─ cudatoolkit >=11.8,<12 , which conflicts with any installable versions previously reported;
│ ├─ jaxlib 0.4.23 would require
│ │ └─ libgrpc >=1.62.1,<1.63.0a0 , which requires
│ │ └─ libprotobuf >=4.25.3,<4.25.4.0a0 , which can be installed;
│ ├─ jaxlib 0.4.23 would require
│ │ ├─ cudatoolkit >=11.8,<12 , which conflicts with any installable versions previously reported;
│ │ └─ libgrpc >=1.59.3,<1.60.0a0 , which requires
│ │ └─ libprotobuf >=4.24.4,<4.24.5.0a0 , which conflicts with any installable versions previously reported;
│ └─ jaxlib 0.4.23 would require
│ └─ python_abi 3.12.* *_cp312, which requires
│ └─ python 3.12.* *_cpython, which can be installed;
├─ python 3.11** is not installable because it conflicts with any installable versions previously reported;
└─ pytorch 2.1.2 *cuda* is installable with the potential options
├─ pytorch 2.1.2 would require
│ ├─ libprotobuf >=4.25.1,<4.25.2.0a0 , which conflicts with any installable versions previously reported;
│ └─ libtorch 2.1.2.* with the potential options
│ ├─ libtorch 2.1.2 would require
│ │ └─ pytorch 2.1.2 cpu_generic_*_0, which can be installed;
│ ├─ libtorch 2.1.2 would require
│ │ └─ pytorch 2.1.2 cpu_generic_*_1, which can be installed;
│ ├─ libtorch 2.1.2 would require
│ │ └─ pytorch 2.1.2 cpu_generic_*_3, which can be installed;
│ ├─ libtorch 2.1.2 would require
│ │ └─ pytorch 2.1.2 cpu_mkl_*_100, which can be installed;
│ ├─ libtorch 2.1.2 would require
│ │ └─ pytorch 2.1.2 cpu_mkl_*_101, which can be installed;
│ ├─ libtorch 2.1.2 would require
│ │ └─ pytorch 2.1.2 cpu_mkl_*_103, which can be installed;
│ ├─ libtorch 2.1.2 would require
│ │ └─ pytorch 2.1.2 cuda112_*_300, which can be installed;
│ ├─ libtorch 2.1.2 would require
│ │ ├─ libprotobuf >=4.25.1,<4.25.2.0a0 , which conflicts with any installable versions previously reported;
│ │ └─ pytorch 2.1.2 cuda112_*_301, which can be installed;
│ ├─ libtorch 2.1.2 would require
│ │ ├─ cudatoolkit >=11.8,<12 , which conflicts with any installable versions previously reported;
│ │ ├─ libprotobuf >=4.25.1,<4.25.2.0a0 , which conflicts with any installable versions previously reported;
│ │ └─ pytorch 2.1.2 cuda118_*_301, which can be installed;
│ ├─ libtorch 2.1.2 would require
│ │ ├─ cudatoolkit >=11.8,<12 , which conflicts with any installable versions previously reported;
│ │ └─ libprotobuf >=4.25.1,<4.25.2.0a0 , which conflicts with any installable versions previously reported;
│ ├─ libtorch 2.1.2 would require
│ │ └─ pytorch 2.1.2 cuda118_*_300, which can be installed;
│ ├─ libtorch 2.1.2 would require
│ │ ├─ cuda-version >=12.0,<13 , which can be installed (as previously explained);
│ │ ├─ libprotobuf >=4.25.1,<4.25.2.0a0 , which conflicts with any installable versions previously reported;
│ │ └─ pytorch 2.1.2 cuda120_*_301, which can be installed;
│ ├─ libtorch 2.1.2 would require
│ │ ├─ cuda-version >=12.0,<13 , which can be installed (as previously explained);
│ │ └─ pytorch 2.1.2 cuda120_*_303, which can be installed;
│ └─ libtorch 2.1.2 would require
│ └─ pytorch 2.1.2 cuda120_*_300, which can be installed;
├─ pytorch 2.1.2 would require
│ └─ libprotobuf >=4.24.4,<4.24.5.0a0 , which conflicts with any installable versions previously reported;
├─ pytorch 2.1.2 would require
│ └─ python_abi 3.12.* *_cp312, which can be installed (as previously explained);
└─ pytorch 2.1.2 would require
├─ cuda-version >=12.0,<13 , which can be installed (as previously explained);
└─ libprotobuf >=4.25.1,<4.25.2.0a0 , which conflicts with any installable versions previously reported.
Once a conda-forge pytorch version gets compiled with libprotobuf==4.25.3 (i.e. once https://github.com/conda-forge/pytorch-cpu-feedstock/pull/228 is ready and merged; big thanks to the pytorch and jax conda-forge maintainers), it should be possible to install both jax and pytorch with CUDA enabled, using CUDA 12 and only conda-forge packages.
JAX 0.4.26 relaxed our CUDA version dependencies so the minimum CUDA version for JAX is 12.1. This is a version also supported by PyTorch. Try it out! We're going to try to make sure our supported version range overlaps with at least one PyTorch release.
Note that we dropped support for CUDA 11.
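The "overlapping version range" idea can be stated concretely: two packages are co-installable, CUDA-wise, when their supported CUDA ranges intersect. A toy sketch; the ranges below are illustrative assumptions, not official support matrices:

```python
def cuda_ranges_overlap(a, b) -> bool:
    # Each range is ((min_major, min_minor), (max_major, max_minor)), inclusive;
    # tuples compare lexicographically, so this is a plain interval intersection.
    return max(a[0], b[0]) <= min(a[1], b[1])

jax_range   = ((12, 1), (12, 9))  # "minimum CUDA for JAX is 12.1"; upper bound assumed
torch_range = ((11, 8), (12, 1))  # illustrative: cu118 and cu121 wheel flavors

assert cuda_ranges_overlap(jax_range, torch_range)  # they meet at CUDA 12.1
# a hypothetical CUDA-11-only build would no longer intersect JAX's range:
assert not cuda_ranges_overlap(jax_range, ((11, 8), (11, 8)))
```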
After a bunch of fixes from both the jax and pytorch maintainers, now (late May 2024) it is possible to just install jax and pytorch from conda-forge on Linux, and out of the box they will work with GPU/CUDA support without the need for any other conda channel:
$ conda create -c conda-forge -n jaxpytorch pytorch jax
$ conda activate jaxpytorch
$ python
Python 3.12.3 | packaged by conda-forge | (main, Apr 15 2024, 18:38:13) [GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import jax
>>> torch.cuda.is_available()
True
>>> jax.devices()
[cuda(id=0)]
>>>
If for some reason this command does not install the CUDA-enabled jax, perhaps you are still using the classic conda solver; in that case you can force the installation of CUDA-enabled jax and pytorch with:
conda create -c conda-forge -n jaxpytorch pytorch=*=cuda* jax jaxlib=*=cuda*
However, this is not necessary if you are using a recent conda install that defaults to the conda-libmamba-solver, see https://www.anaconda.com/blog/a-faster-conda-for-a-growing-community .
Can someone please point out the correct version necessary to get pytorch and jax both with GPU support on CUDA 12 as of July 2024? I would prefer it to be a standard venv rather than a conda env, but either is fine.
@varadVaidya Totally by chance I follow this issue, but in general you may have more success finding help through the official jax help channels (see https://jax.readthedocs.io/en/latest/beginner_guide.html#finding-help) rather than posting in closed issues.
More on topic, I have no idea about pip/venv with CUDA, but for conda the procedure posted in https://github.com/google/jax/issues/18032#issuecomment-2132399059 is working fine for me. (When I originally posted the message I forgot to add the -c conda-forge, needed to ensure it also works on anaconda or miniconda installations that use the defaults channel instead of conda-forge; I just fixed that to avoid confusion.)
@eliseoe @bebark @shaikalthaf4 By chance I just noticed that you added a 👎🏽 reaction to my previous comment; any reason for doing so? Just FYI, authors do not get (at least by default) notifications for post reactions.
@traversaro I found that running your command with conda will install:
jaxlib conda-forge/linux-64::jaxlib-0.4.27-cpu_py312h17e8b90_0
whereas with mamba the correct version is installed:
mamba create -c conda-forge -n jaxpytorch pytorch jax
jaxlib 0.4.27 cuda120py312h4008524_200 conda-forge/linux-64
Perhaps this is why you got 3 thumbs down
@peterch405
Interestingly, in my system with:
root@DESKTOP-T0NQNLN:~# conda info
active environment : None
shell level : 0
user config file : /root/.condarc
populated config files : /root/miniforge3/.condarc
/root/.condarc
conda version : 24.3.0
conda-build version : not installed
python version : 3.10.14.final.0
solver : libmamba (default)
virtual packages : __archspec=1=skylake
__conda=24.3.0=0
__cuda=12.0=0
__glibc=2.39=0
__linux=5.15.153.1=0
__unix=0=0
base environment : /root/miniforge3 (writable)
conda av data dir : /root/miniforge3/etc/conda
conda av metadata url : None
channel URLs : https://conda.anaconda.org/conda-forge/linux-64
https://conda.anaconda.org/conda-forge/noarch
package cache : /root/miniforge3/pkgs
/root/.conda/pkgs
envs directories : /root/miniforge3/envs
/root/.conda/envs
platform : linux-64
user-agent : conda/24.3.0 requests/2.31.0 CPython/3.10.14 Linux/5.15.153.1-microsoft-standard-WSL2 ubuntu/24.04 glibc/2.39 solver/libmamba conda-libmamba-solver/24.1.0 libmambapy/1.5.8
UID:GID : 0:0
netrc file : None
offline mode : False
the command
conda create -c conda-forge -n jaxpytorch pytorch jax
installs the CUDA jax, but indeed:
conda create --solver=classic -c conda-forge -n jaxpytorch pytorch jax
installs the CPU jax. Perhaps you are using an old conda version that still defaults to the classic solver? (You can check this in the conda info output, see https://www.anaconda.com/blog/a-faster-conda-for-a-growing-community.)
However, even with the classic solver, forcing it to install the CUDA versions of jaxlib and pytorch works as expected (even if the classic solver is much slower):
conda create --solver=classic -c conda-forge -n jaxpytorch pytorch=*=cuda* jax jaxlib=*=cuda*
I edited the original comment accordingly.
You are right, I'm using the classic solver:
active environment : base
active env location : /home/chovanec/miniconda3
shell level : 1
user config file : /home/chovanec/.condarc
populated config files : /home/chovanec/.condarc
conda version : 23.1.0
conda-build version : not installed
python version : 3.10.8.final.0
virtual packages : __archspec=1=x86_64
__cuda=12.3=0
__glibc=2.31=0
__linux=5.15.153.1=0
__unix=0=0
base environment : /home/chovanec/miniconda3 (writable)
conda av data dir : /home/chovanec/miniconda3/etc/conda
conda av metadata url : None
channel URLs : https://conda.anaconda.org/bioconda/linux-64
https://conda.anaconda.org/bioconda/noarch
https://conda.anaconda.org/conda-forge/linux-64
https://conda.anaconda.org/conda-forge/noarch
https://repo.anaconda.com/pkgs/main/linux-64
https://repo.anaconda.com/pkgs/main/noarch
https://repo.anaconda.com/pkgs/r/linux-64
https://repo.anaconda.com/pkgs/r/noarch
package cache : /home/chovanec/miniconda3/pkgs
/home/chovanec/.conda/pkgs
envs directories : /home/chovanec/miniconda3/envs
/home/chovanec/.conda/envs
platform : linux-64
user-agent : conda/23.1.0 requests/2.28.1 CPython/3.10.8 Linux/5.15.153.1-microsoft-standard-WSL2 ubuntu/20.04.6 glibc/2.31
UID:GID : 1000:1000
netrc file : None
offline mode : False
Thanks for the solution. However, I found a possible bug: jax.numpy cannot initialize an array bigger than (2, 52, 10) with both jax and jaxlib at version 0.4.30, so I had to downgrade to 0.4.23 and then it works just fine. So, to be safe, the command could be:
conda create -n _env_name_ jaxlib=0.4.23 pytorch cuda-nvcc python=3.11 -c conda-forge -c nvidia -c pytorch
(Python 3.12 is too new for some commonly used packages.)
Just a curiosity: are you actually getting any packages from the nvidia or pytorch channels? If the conda-forge channel is used and you are using strict priority, all the packages you get should come from conda-forge, so I guess you could drop the -c nvidia -c pytorch from your command. However, you can check this by calling conda list and checking where the packages were installed from.
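The `conda list` check can be automated by counting packages per channel; in `conda list` output the channel is the last column of each row. A stdlib sketch over a pasted sample (the two rows below are illustrative):

```python
from collections import Counter

SAMPLE = """\
jaxlib     0.4.27   cuda120py312h4008524_200   conda-forge
pytorch    2.3.0    cuda120_py312_301          conda-forge
"""

def channel_counts(listing: str) -> Counter:
    # Count how many packages came from each channel (last whitespace-separated
    # column of each row); comment lines starting with '#' are skipped.
    counts = Counter()
    for line in listing.splitlines():
        parts = line.split()
        if parts and not parts[0].startswith("#"):
            counts[parts[-1]] += 1
    return counts

assert channel_counts(SAMPLE) == Counter({"conda-forge": 2})
```

If every count lands on conda-forge, the extra channels were indeed unused.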
I'm not sure; maybe I can do a test later. Thanks for the heads-up.
Sorry for the late reply; here is the output. Since jax and the jax CUDA libraries were manually reinstalled from PyPI, I guess the remaining packages were indeed preferentially installed from conda-forge :)
Not sure how you can end up with jax/jaxlib installed via PyPI if you just created the environment with conda create -n _env_name_ jaxlib=0.4.23 pytorch cuda-nvcc python=3.11 -c conda-forge -c nvidia -c pytorch, but as a general comment, if you are installing something with pip it is a good idea not to also install it via conda, to avoid conflicts.
In my case, the conflict came from torch and jaxlib sticking to different cuDNN versions; previously I hadn't turned to conda-forge to install a cudatoolkit compatible with both torch and jaxlib. I used the pip command from the official jax documentation, by the way.
Ok, but in that case it is probably a good idea not to install jax and jaxlib from conda, and to only install them from pip.
I think the only reason to also list jax and jaxlib in the conda command is to make sure conda-forge can resolve and install a compatible cuDNN version. I did not test this, so to be safe I recommend (annoyingly) reinstalling jax and jaxlib from pip afterwards.
But conda has no idea which version of cuDNN the jaxlib installed via pip requires. If you want to install cuDNN (even a specific version) with conda, just install cudnn; to avoid problems, it is typically best not to install jax or jaxlib via conda if you are installing them via pip.
You are right; I accidentally used the pip install, and it just happened to find a cuDNN version that meets the requirement, lol.
@traversaro apologies to have to revive this issue, but your solution does not work for me:
$ conda create -c conda-forge -n jaxpytorch pytorch jax
$ conda activate jaxpytorch
$ python
Python 3.12.4 | packaged by conda-forge | (main, Jun 17 2024, 10:23:07) [GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import jax
>>> torch.cuda.is_available()
True
>>> jax.devices()
CUDA backend failed to initialize: Unable to use CUDA because of the following issues with CUDA components:
Outdated cuDNN installation found.
Version JAX was built against: 8907
Minimum supported: 9100
Installed version: 8907
The local installation version must be no lower than 9100..(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[CpuDevice(id=0)]
The error message suggests I need cudnn>=9.1.0; however, the most recent version on conda-forge appears to be 8.9.7.29.
System info:
Operating System: Ubuntu 22.04.4 LTS
GPU: NVIDIA GeForce GTX 1060 6GB
Graphics Driver: NVIDIA driver metapackage from nvidia-driver-535
mamba list output below the fold:
UPDATE: jax=0.4.28 ~appears to~ does work, so this looks like a bug introduced recently in JAX. cc @hawkinsp
@lucascolley thanks for reporting the issue, can you please open an issue in https://github.com/conda-forge/jaxlib-feedstock and tag me there? Thanks!
Thanks @lucascolley, indeed it seems a regression in the conda-forge jax package 0.4.31, I opened https://github.com/conda-forge/jaxlib-feedstock/issues/277 to track the problem.
Description
When I pip install only the latest jax with CUDA (pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html), I can use jax with the GPU. But when I pip install torch afterwards (pip install torch), I can no longer use jax with the GPU; it reports that the CUDA or cuSOLVER version is older than the one jax was built against. Why? Can an older jax version avoid this? If so, how can I pip install jax[cuda] at the relevant version?
What jax/jaxlib version are you using?
jax-0.4.18 jaxlib-0.4.18+cuda12.cudnn89
Which accelerator(s) are you using?
GPU
Additional system info
3.10.9/Linux
NVIDIA GPU info
No response