wula2048 closed this issue 9 months ago.
I think the problem is that you are using CUDA 12 but the Windows builds of JAX only support CUDA 11 currently.
I downloaded CUDA 11.8 from the official NVIDIA website and successfully installed it, but I still encountered the same error.
Hi,
Sorry, I should have asked earlier whether you installed via conda or pip. If you install keypoint-moseq using one of the conda env files, conda installs its own copy of CUDA that is separate from the system install. In theory that should have worked. I can re-test that installation method on Monday when I have access to a Windows machine. In the meantime, you are welcome to try the pip installation route, which would involve:
Hello,
Following your instructions, I deleted the original 'keypoint_moseq' environment and reinstalled using pip. Here are the steps I took in the terminal:
conda env remove -n keypoint_moseq
conda create -n keypoint_moseq python=3.9
pip install jax==0.3.22 https://whls.blob.core.windows.net/unstable/cuda111/jaxlib-0.3.22+cuda11.cudnn82-cp39-cp39-win_amd64.whl
pip install keypoint-moseq
conda install -c conda-forge pytables
conda install numpy=1.24
python -m ipykernel install --user --name=keypoint_moseq
python
import keypoint_moseq as kpms
Subsequently, I encountered the following error:
(keypoint_moseq) C:\Users\85033>python
Python 3.9.18 | packaged by conda-forge | (main, Aug 30 2023, 03:40:31) [MSC v.1929 64 bit (AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import keypoint_moseq as kpms
2023-10-15 20:49:22.276586: F external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:454] ptxas returned an error during compilation of ptx to sass: 'INTERNAL: ptxas exited with non-zero error code -1, output: ' If the error message indicates that a file could not be written, please verify that sufficient filesystem space is provided.
I made sure that CUDA 11.1 and cuDNN 8.2 are being used in the environment. nvcc --version output:
(keypoint_moseq) C:\Users\85033>nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Mon_Oct_12_20:54:10_Pacific_Daylight_Time_2020
Cuda compilation tools, release 11.1, V11.1.105
Build cuda_11.1.relgpu_drvr455TC455_06.29190527_0
cudnn version:
Upon searching for this error on Google, I found a report of a similar issue at 'https://github.com/cloudhan/jax-windows-builder/issues/19'. I realized that this could be an issue with JAX. I followed the instructions at 'https://github.com/cloudhan/jax-windows-builder/issues/19#issuecomment-1514294488' to set up a new environment, but unfortunately, the problem persists, and I'm still getting the same error.
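Since the failure comes from ptxas, it can help to confirm which CUDA tools the active environment actually resolves on PATH, because a conda environment and a system install can each provide their own copies. The following is a minimal sketch, not part of keypoint-moseq or JAX: the helper name `cuda_tool_report` is hypothetical, and it assumes both tools accept `--version` (as `nvcc` does in the output above).

```python
import shutil
import subprocess


def cuda_tool_report(tools=("nvcc", "ptxas")):
    """Map each tool name to its resolved path and version text, or None if absent.

    Hypothetical helper for illustration only.
    """
    report = {}
    for tool in tools:
        # shutil.which follows the same PATH lookup the shell would use.
        path = shutil.which(tool)
        if path is None:
            report[tool] = None
            continue
        try:
            proc = subprocess.run(
                [path, "--version"],
                capture_output=True,
                text=True,
                timeout=30,
            )
            version = (proc.stdout or proc.stderr).strip()
        except (OSError, subprocess.TimeoutExpired) as exc:
            version = f"failed to run: {exc}"
        report[tool] = {"path": path, "version": version}
    return report


if __name__ == "__main__":
    for name, info in cuda_tool_report().items():
        if info is None:
            print(f"{name}: not found on PATH")
        else:
            print(f"{name}: {info['path']}")
            print(info["version"])
```

If the two tools resolve to different CUDA installations, or `ptxas` is missing entirely, that mismatch would be worth including in any JAX bug report.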
Hmm, I'm not sure what to suggest from here. It seems like a JAX problem, so it might be worth posting an issue elsewhere or following other suggestions from Google. If you do manage to create an env where you are able to import jax and run e.g. jax.random.PRNGKey(0) without error, then I am happy to help make sure you are able to install the rest of the keypoint-moseq package on top of that.
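The check described above can be scripted. This is a rough sketch rather than an official diagnostic: the helper name `check_jax` and the timeout value are assumptions. It runs the import in a child interpreter because a fatal XLA error such as the ptxas failure earlier in this thread aborts the whole process and cannot be caught with try/except.

```python
import subprocess
import sys

# Code executed in a child interpreter so that a fatal XLA abort
# cannot take down the calling process.
CHECK_SNIPPET = (
    "import jax\n"
    "print('jax', jax.__version__)\n"
    "print('devices', jax.devices())\n"
    "key = jax.random.PRNGKey(0)\n"
    "print('key', key)\n"
)


def check_jax(timeout=300):
    """Return (ok, output) for a minimal JAX import-and-run test.

    Hypothetical helper for illustration only.
    """
    try:
        proc = subprocess.run(
            [sys.executable, "-c", CHECK_SNIPPET],
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        return False, "timed out"
    # A non-zero return code covers both Python exceptions and hard aborts.
    return proc.returncode == 0, proc.stdout + proc.stderr


if __name__ == "__main__":
    ok, output = check_jax()
    print(output)
    print("JAX check passed" if ok else "JAX check FAILED")
```

Run it inside the activated environment. If it passes and the printed devices include a GPU, the remaining keypoint-moseq dependencies can be installed on top of that environment.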
Thank you for your quick response. I will continue to try installing JAX in a new environment.
I'm now using keypoint-moseq via WSL2 (CUDA 11.8, cuDNN 8.8) and it's working fine for the demo data, so I'll close this issue.
Hi all,
I am trying to use keypoint-moseq. After setting up the environment using the conda method, I tried to validate the process with the demo data you provided, but I get the following error when running
model = kpms.init_model(data, pca=pca, **config())
Additional Information
Operating System
keypoint-moseq version
cuda and cudnn version
nvcc --version output
nvidia-smi output
![image](https://github.com/wula2048/-/assets/87176207/4ce6dbbd-ec6d-4928-b7a5-dd1d90aed53b)
I noticed that a similar issue was reported before (#69), but I did not see an effective solution there. I look forward to your help in resolving this issue. Thank you!