eterevsky opened this issue 1 year ago
I cannot reproduce it with CUDA 11.5. If the same error does not occur with a newer version of CUDA, then I think you had better report it to the JAX team directly.
Note: the cuda111 build is compatible with CUDA 11.1 through 11.7.
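A minimal sketch that encodes the compatibility range stated above, assuming a local CUDA version string in "major.minor" or "major.minor.patch" form. The helper names `parse_cuda_version` and `cuda111_build_supports` are hypothetical; they only restate the 11.1 to 11.7 range from this thread, not an official support matrix.

```python
from typing import Tuple


def parse_cuda_version(version: str) -> Tuple[int, int]:
    """Return (major, minor) from a string such as "11.7" or "11.7.1"."""
    parts = version.strip().split(".")
    return int(parts[0]), int(parts[1])


def cuda111_build_supports(version: str) -> bool:
    """True if `version` lies in the 11.1 to 11.7 range stated for the cuda111 build."""
    return (11, 1) <= parse_cuda_version(version) <= (11, 7)


if __name__ == "__main__":
    for v in ["11.1", "11.5", "11.7.1", "11.8", "11.0"]:
        print(v, cuda111_build_supports(v))
```

By this range alone, 11.8 returns False; whether it would actually work is not established here.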
Thank you for the quick response. I didn't know other versions of CUDA were supported. The README says "Currently, only CPU and CUDA 11.1 are supported." Does jaxlib also support CUDA 11.8?
I tried upgrading CUDA to 11.7.1 and cuDNN to 8.7. The installed driver version is 527.56, which is newer than the 516.94 driver that comes with CUDA 11.7.1, and the CUDA installer refuses to downgrade it. (Could this be a problem?)
Then I reinstalled jaxlib and jax:
$ pip install --force-reinstall ../Downloads/jaxlib-0.3.25+cuda11.cudnn82-cp310-cp310-win_amd64.whl --use-deprecated legacy-resolver
[...]
Successfully installed jaxlib-0.3.25 numpy-1.24.0 scipy-1.9.3
$ pip install --force-reinstall "jax===0.3.25" --use-deprecated legacy-resolver
[...]
Successfully installed jax-0.3.25 numpy-1.24.0 opt_einsum-3.3.0 scipy-1.9.3 typing_extensions-4.4.0
(Is this the correct way to install them?)
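One way to confirm what pip actually installed, for the same interpreter that will import JAX, is `importlib.metadata` from the standard library. A small sketch; `installed_versions` is a hypothetical helper name.

```python
import sys
from importlib import metadata
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    result: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            result[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    # The interpreter whose site-packages are being inspected.
    print(sys.executable)
    for name, version in installed_versions(["jax", "jaxlib", "numpy", "scipy"]).items():
        print(f"{name}: {version or 'not installed'}")
```

From the command line, `python -m pip show jaxlib` reports similar information.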
Then I tried JAX again and got the same result:
$ python
Python 3.10.9 (tags/v3.10.9:1dd9be6, Dec 6 2022, 20:01:21) [MSC v.1934 64 bit (AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax.numpy as jnp
>>> a = jnp.array([1, 2, 3])
>>> a
DeviceArray([1, 2, 3], dtype=int32)
>>> a + a
2022-12-19 08:19:19.624302: F external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:453] ptxas returned an error during compilation of ptx to sass: 'INTERNAL: ptxas exited with non-zero error code -1, output: ' If the error message indicates that a file could not be written, please verify that sufficient filesystem space is provided.
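The log suggests checking free filesystem space if a file could not be written. A quick standard-library check of the temp directory, assuming (not confirmed by the log) that the compiler writes its intermediate PTX/SASS files under the system temp directory; `free_space_gib` is a hypothetical helper.

```python
import shutil
import tempfile


def free_space_gib(path: str) -> float:
    """Free space on the filesystem containing `path`, in GiB."""
    return shutil.disk_usage(path).free / 2**30


if __name__ == "__main__":
    tmp = tempfile.gettempdir()
    print(f"{tmp}: {free_space_gib(tmp):.1f} GiB free")
```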
Based on your command lines, you installed jaxlib and jax correctly. This is indeed very strange. Are you using the Python distributed through the Microsoft Store? If so, please stop using it and avoid it like the plague.
OK, I was thinking that you are not the only victim of Microsoft Python (https://github.com/cloudhan/jax-windows-builder/issues/16), but it turns out that you are stepping into the same river twice.
That would be very embarrassing, so I double-checked that I'm using the correct version of Python before submitting this bug. This is Python 3.10.9, installed from https://www.python.org/downloads/ just yesterday. I have just verified that I don't have the MS Store version installed and that pip belongs to the correct Python installation.
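A heuristic sketch for the Microsoft Store check discussed above, assuming the Store build runs from a path containing `WindowsApps` (where Windows installs Store apps). This is not an official detection API; `looks_like_ms_store_python` is a hypothetical helper.

```python
import sys
from typing import Optional


def looks_like_ms_store_python(executable: Optional[str] = None) -> bool:
    """Heuristic: Microsoft Store Python runs from a ...\\WindowsApps\\... path."""
    path = executable if executable is not None else sys.executable
    return "windowsapps" in path.lower()


if __name__ == "__main__":
    print("interpreter:", sys.executable)
    print("prefix:", sys.prefix)
    print("looks like Microsoft Store Python:", looks_like_ms_store_python())
```

To confirm that pip belongs to the same interpreter, `python -m pip --version` prints the directory pip is running from.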
For some context, I got JAX working last time (thank you!), but for some reason its performance was significantly lower than on Ubuntu on the same machine, so I haven't used it on Windows since. Now I've decided to give it another go.
I've just installed the jax + jaxlib from jaxlib-0.3.25+cuda11.cudnn82-cp310-cp310-win_amd64.whl and I'm getting the following error: (skipping the middle code) a + a 2022-12-18 15:28:18.688396: F external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:453] ptxas returned an error during compilation of ptx to sass: 'INTERNAL: ptxas exited with non-zero error code -1, output: ' If the error message indicates that a file could not be written, please verify that sufficient filesystem space is provided.
I just installed JAX on Windows and your code works. I will outline the steps I took; perhaps they will help.
Update: I tried the steps below with CUDA 11.7 and it still works.
conda create -n jaxg python=3.9 -c conda-forge
conda doesn't install all components, so we also need to install cuda-nvcc from the nvidia channel, which provides ptxas:
conda install cuda-nvcc -c nvidia/label/cuda-11.3.0
conda install cudatoolkit=11.1 cudnn=8.2 -c conda-forge
cudatoolkit 11.1.1 hb074779_11 conda-forge/win-64 868MB
cudnn 8.2.1.32 h754d62a_0 conda-forge/win-64 700MB
https://whls.blob.core.windows.net/unstable/cuda111/jaxlib-0.3.25+cuda11.cudnn82-cp39-cp39-win_amd64.whl
pip install jax[cuda111]==0.3.25 -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
pip install jaxlib[cuda111]==0.3.25 -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
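The wheel names in this thread encode the CPython version they were built for (cp310 for the Python 3.10.9 install, cp39 for the conda environment above). A minimal sketch that splits a bare wheel filename following the PEP 427 naming convention and checks its Python tag against the running interpreter. `parse_wheel_filename` and `matches_this_python` are hypothetical helpers, and the simple equality check does not handle compressed tag sets such as `py2.py3`.

```python
import sys
from typing import NamedTuple


class WheelTags(NamedTuple):
    name: str
    version: str
    python: str
    abi: str
    platform: str


def parse_wheel_filename(filename: str) -> WheelTags:
    """Split a bare wheel filename: name-version(-build)?-python-abi-platform.whl."""
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel: {filename}")
    parts = filename[: -len(".whl")].split("-")
    if len(parts) not in (5, 6):  # 6 when the optional build tag is present
        raise ValueError(f"unexpected wheel name: {filename}")
    return WheelTags(parts[0], parts[1], parts[-3], parts[-2], parts[-1])


def matches_this_python(tags: WheelTags) -> bool:
    """True if the wheel's CPython tag (e.g. cp39, cp310) matches this interpreter."""
    return tags.python == f"cp{sys.version_info.major}{sys.version_info.minor}"


if __name__ == "__main__":
    wheel = parse_wheel_filename("jaxlib-0.3.25+cuda11.cudnn82-cp39-cp39-win_amd64.whl")
    print(wheel)
    print("matches this interpreter:", matches_this_python(wheel))
```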
(jaxg) C:\condadir>nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Sun_Mar_21_19:24:09_Pacific_Daylight_Time_2021
Cuda compilation tools, release 11.3, V11.3.58
Build cuda_11.3.r11.3/compiler.29745058_0
(jaxg) C:\condadir>ptxas --version
ptxas: NVIDIA (R) Ptx optimizing assembler
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Sun_Mar_21_19:22:56_Pacific_Daylight_Time_2021
Cuda compilation tools, release 11.3, V11.3.58
Build cuda_11.3.r11.3/compiler.29745058_0
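In this environment nvcc and ptxas both report release 11.3. Since the original failure is ptxas exiting during compilation, it may help to check which ptxas is first on PATH (what a shell would pick up; whether XLA uses the same one is an assumption) and what version it reports. A small sketch; `parse_cuda_tool_version` and `ptxas_version` are hypothetical helpers, and the regex assumes the "release X.Y, VX.Y.Z" line format shown above.

```python
import re
import shutil
import subprocess
from typing import Optional, Tuple

# Matches e.g. "Cuda compilation tools, release 11.3, V11.3.58".
_RELEASE_RE = re.compile(r"release (\d+)\.(\d+), V([\d.]+)")


def parse_cuda_tool_version(output: str) -> Optional[Tuple[int, int, str]]:
    """Extract (major, minor, full) from `nvcc --version` or `ptxas --version` text."""
    match = _RELEASE_RE.search(output)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2)), match.group(3)


def ptxas_version() -> Optional[Tuple[int, int, str]]:
    """Run the first ptxas found on PATH; return None if it is missing or unparseable."""
    exe = shutil.which("ptxas")
    if exe is None:
        return None
    out = subprocess.run([exe, "--version"], capture_output=True, text=True, check=False)
    return parse_cuda_tool_version(out.stdout)


if __name__ == "__main__":
    print("ptxas on PATH:", shutil.which("ptxas"))
    print("ptxas version:", ptxas_version())
```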
(jaxg) C:\condadir>python
Python 3.9.16 | packaged by conda-forge | (main, Feb 1 2023, 21:28:38) [MSC v.1929 64 bit (AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax.numpy as jnp
>>> a = jnp.array([1, 2, 3])
>>> a
DeviceArray([1, 2, 3], dtype=int32)
>>> a + a
DeviceArray([2, 4, 6], dtype=int32)
>>>
I've just installed the jax + jaxlib from jaxlib-0.3.25+cuda11.cudnn82-cp310-cp310-win_amd64.whl and I'm getting the following error:
ptxas version: