cloudhan / jax-windows-builder

A community supported Windows build for jax.

"ptxas returned an error during compilation of ptx to sass" #19

Open eterevsky opened 1 year ago

eterevsky commented 1 year ago

I've just installed the jax + jaxlib from jaxlib-0.3.25+cuda11.cudnn82-cp310-cp310-win_amd64.whl and I'm getting the following error:

Python 3.10.9 (tags/v3.10.9:1dd9be6, Dec  6 2022, 20:01:21) [MSC v.1934 64 bit (AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax.numpy as jnp
>>> a = jnp.array([1, 2, 3])
>>> a
DeviceArray([1, 2, 3], dtype=int32)
>>> a + a
2022-12-18 15:28:18.688396: F external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:453] ptxas returned an error during compilation of ptx to sass: 'INTERNAL: ptxas exited with non-zero error code -1, output: '  If the error message indicates that a file could not be written, please verify that sufficient filesystem space is provided.
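
The error text suggests checking filesystem space when a file could not be written. A minimal sketch of such a check follows, assuming the intermediate files go to the system temp directory (the thread does not confirm this). The 1 GiB threshold and the helper name `temp_dir_free_space` are arbitrary and only for illustration.

```python
import shutil
import tempfile


def temp_dir_free_space(min_free_bytes=1 * 1024**3):
    """Return (temp_dir, free_bytes, enough) for the current temp directory.

    Assumption: the compiler writes its scratch files to the system temp
    directory. The default threshold of 1 GiB is an arbitrary choice.
    """
    temp_dir = tempfile.gettempdir()
    free_bytes = shutil.disk_usage(temp_dir).free
    return temp_dir, free_bytes, free_bytes >= min_free_bytes


temp_dir, free_bytes, enough = temp_dir_free_space()
print(f"temp dir: {temp_dir}")
print(f"free space: {free_bytes / 1024**3:.1f} GiB (enough: {enough})")
```

If plenty of space is reported, disk space is probably not the cause and the empty ptxas output points elsewhere.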

ptxas version:

$ ptxas --version
ptxas: NVIDIA (R) Ptx optimizing assembler
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Tue_Sep_15_19:11:24_Pacific_Daylight_Time_2020
Cuda compilation tools, release 11.1, V11.1.74
Build cuda_11.1.relgpu_drvr455TC455_06.29069683_0
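
When comparing setups it can help to pull the CUDA release out of the `ptxas --version` text programmatically. This is a small sketch that only parses the `release X.Y` part of output like the one above. The helper name `parse_cuda_release` is made up for illustration.

```python
import re


def parse_cuda_release(version_output):
    """Extract (major, minor) from the 'release X.Y' part of the output.

    Returns None when no such pattern is found.
    """
    match = re.search(r"release (\d+)\.(\d+)", version_output)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


# Two lines copied from the ptxas output quoted in this comment.
sample = """ptxas: NVIDIA (R) Ptx optimizing assembler
Cuda compilation tools, release 11.1, V11.1.74"""
print(parse_cuda_release(sample))  # (11, 1)
```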

cloudhan commented 1 year ago

I cannot reproduce it with CUDA 11.5. If the same error doesn't occur with a newer version of CUDA, then I think you'd better report it to the JAX team directly.

Note: the cuda111 build is compatible with CUDA 11.1 through 11.7.
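
The compatibility note above can be written down as a simple range check. This is only a sketch of the stated range (11.1 to 11.7). The function name `in_stated_cuda111_range` is hypothetical, and a `False` result means "not confirmed in this thread", not "known to be broken".

```python
def in_stated_cuda111_range(cuda_version):
    """Check a 'major.minor[.patch]' string against the range stated above.

    Only the first two components are compared, so '11.7.1' counts as 11.7.
    """
    major, minor = (int(part) for part in cuda_version.split(".")[:2])
    return (11, 1) <= (major, minor) <= (11, 7)


for version in ("11.1", "11.5", "11.7.1", "11.8"):
    print(version, in_stated_cuda111_range(version))
```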

eterevsky commented 1 year ago

Thank you for the quick response. I didn't know other versions of CUDA are supported. The README says "Currently, only CPU and CUDA 11.1 are supported." Does jaxlib also support CUDA 11.8?

I tried upgrading CUDA to 11.7.1 and cuDNN to 8.7. The installed driver version is 527.56, which is newer than the 516.94 that comes with CUDA 11.7.1, but the CUDA installer refuses to downgrade it. (Could this be a problem?)

Then I reinstalled jaxlib and jax:

$ pip install --force-reinstall ../Downloads/jaxlib-0.3.25+cuda11.cudnn82-cp310-cp310-win_amd64.whl --use-deprecated legacy-resolver
[...]
Successfully installed jaxlib-0.3.25 numpy-1.24.0 scipy-1.9.3
$ pip install --force-reinstall "jax===0.3.25" --use-deprecated legacy-resolver
[...]
Successfully installed jax-0.3.25 numpy-1.24.0 opt_einsum-3.3.0 scipy-1.9.3 typing_extensions-4.4.0

(Is this the correct way to install them?)
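
One way to double check what actually ended up in the environment is to ask Python's package metadata. This sketch uses the standard library's `importlib.metadata` and needs Python 3.8 or newer. The helper name `installed_versions` is made up for illustration.

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib")):
    """Map each package name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


print(installed_versions())
```

Running it with the same interpreter that later imports jax shows whether the two versions match.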

Then again I tried to use JAX and had the same result:

$ python
Python 3.10.9 (tags/v3.10.9:1dd9be6, Dec  6 2022, 20:01:21) [MSC v.1934 64 bit (AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax.numpy as jnp
>>> a = jnp.array([1, 2, 3])
>>> a
DeviceArray([1, 2, 3], dtype=int32)
>>> a + a
2022-12-19 08:19:19.624302: F external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:453] ptxas returned an error during compilation of ptx to sass: 'INTERNAL: ptxas exited with non-zero error code -1, output: '  If the error message indicates that a file could not be written, please verify that sufficient filesystem space is provided.

cloudhan commented 1 year ago

Based on your command line, you installed jaxlib and jax correctly. This is indeed very strange. Are you using the Python distributed through the Microsoft Store? If so, please stop using it and avoid it like the plague.

cloudhan commented 1 year ago

OK, I was thinking that you were not the only victim of Microsoft Python (https://github.com/cloudhan/jax-windows-builder/issues/16), but it turns out that you are stepping into the same river twice.

eterevsky commented 1 year ago

That would be very embarrassing, so I tried to double check that I'm using the correct version of Python before submitting this bug. This is Python 3.10.9 installed from https://www.python.org/downloads/ just yesterday. I just verified that I don't have MS Store version installed and that pip comes from the correct package.
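
A quick way to see which interpreter is running is to print `sys.executable`. The sketch below adds a rough heuristic on top, assuming that Microsoft Store builds run from a path containing a `WindowsApps` directory. That is an assumption rather than something stated in this thread, and `looks_like_store_python` is a hypothetical helper name.

```python
import sys


def looks_like_store_python(executable_path):
    """Heuristic only: assume Store installs live under a 'WindowsApps' folder."""
    return "windowsapps" in executable_path.lower()


print(sys.executable)
print(looks_like_store_python(sys.executable))
```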

For some context, I made JAX work last time (thank you!), but for some reason its performance was significantly lower than on Ubuntu on the same machine, so I didn't use it on Windows since then. Now I decided to give it another go.

fortunewalla commented 1 year ago

I've just installed the jax + jaxlib from jaxlib-0.3.25+cuda11.cudnn82-cp310-cp310-win_amd64.whl and I'm getting the following error: (skipping the middle code)

a + a
2022-12-18 15:28:18.688396: F external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:453] ptxas returned an error during compilation of ptx to sass: 'INTERNAL: ptxas exited with non-zero error code -1, output: ' If the error message indicates that a file could not be written, please verify that sufficient filesystem space is provided.

I just installed JAX for Windows and your code works. I will outline the steps I took. Perhaps they will help.

Update: I tried the steps below with CUDA 11.7 and they still work.

1. Create a new conda environment

conda create -n jaxg python=3.9 -c conda-forge

2. Installing cudatoolkit from the official conda channel doesn't install all components. Hence we need to install cuda-nvcc from the nvidia channel, which also installs ptxas

conda install cuda-nvcc -c nvidia/label/cuda-11.3.0

3. We have to install cudatoolkit 11.1 & cudnn 8.2 for jaxlib to work.

conda install cudatoolkit=11.1 cudnn=8.2 -c conda-forge

   cudatoolkit    11.1.1  hb074779_11  conda-forge/win-64     868MB
   cudnn        8.2.1.32  h754d62a_0   conda-forge/win-64     700MB

4. Wheel URL of the suitable jaxlib

https://whls.blob.core.windows.net/unstable/cuda111/jaxlib-0.3.25+cuda11.cudnn82-cp39-cp39-win_amd64.whl

5. Commands to install the GPU build of JAX using pip

pip install jax[cuda111]==0.3.25 -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver

pip install jaxlib[cuda111]==0.3.25 -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver

6. Checking NVCC

(jaxg) C:\condadir>nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Sun_Mar_21_19:24:09_Pacific_Daylight_Time_2021
Cuda compilation tools, release 11.3, V11.3.58
Build cuda_11.3.r11.3/compiler.29745058_0

7. Checking ptxas

(jaxg) C:\condadir>ptxas --version
ptxas: NVIDIA (R) Ptx optimizing assembler
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Sun_Mar_21_19:22:56_Pacific_Daylight_Time_2021
Cuda compilation tools, release 11.3, V11.3.58
Build cuda_11.3.r11.3/compiler.29745058_0
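
Since a conda environment and a system-wide CUDA install can each provide their own ptxas, it may be worth listing every copy on PATH in lookup order. The following is a sketch of that idea. `find_all_on_path` is a hypothetical helper, and whether jaxlib picks the first hit on PATH is an assumption this thread does not confirm.

```python
import os


def find_all_on_path(name, path=None, exts=("", ".exe")):
    """Return every file called `name` (with any of `exts`) on PATH, in order."""
    path = os.environ.get("PATH", "") if path is None else path
    hits = []
    for directory in path.split(os.pathsep):
        if not directory:
            continue  # skip empty PATH entries
        for ext in exts:
            candidate = os.path.join(directory, name + ext)
            if os.path.isfile(candidate):
                hits.append(candidate)
    return hits


for location in find_all_on_path("ptxas"):
    print(location)
```

More than one line of output means several ptxas binaries are visible, and their versions can then be compared one by one.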

8. Your Code

(jaxg) C:\condadir>python
Python 3.9.16 | packaged by conda-forge | (main, Feb  1 2023, 21:28:38) [MSC v.1929 64 bit (AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax.numpy as jnp
>>> a = jnp.array([1, 2, 3])
>>> a
DeviceArray([1, 2, 3], dtype=int32)
>>> a + a
DeviceArray([2, 4, 6], dtype=int32)
>>>
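
The same check can be put into a small script that also reports which backend JAX selected. This is a sketch, not an official diagnostic. `gpu_smoke_test` is a made-up name, and it returns None when jax is not importable. The `F` prefix in the log above suggests a fatal error that ends the process, so the ptxas failure likely cannot be caught with try/except here.

```python
def gpu_smoke_test():
    """Run the `a + a` check from this thread and report the active backend.

    Returns None if jax is not installed, else (backend_name, result_list).
    """
    try:
        import jax
        import jax.numpy as jnp
    except ImportError:
        return None
    a = jnp.array([1, 2, 3])
    result = (a + a).tolist()  # forces the addition to be compiled and run
    return jax.default_backend(), result


print(gpu_smoke_test())
```

A backend name of `gpu` together with `[2, 4, 6]` would match the working session shown above. A backend of `cpu` would mean the GPU build is not being used.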