dattalab / keypoint-moseq

https://keypoint-moseq.readthedocs.io

cuSolver internal error #69

Closed: micahbaldonado closed this issue 1 year ago

micahbaldonado commented 1 year ago

Hi all,

I have attempted the local Windows GPU installation of keypoint-moseq via both the conda and pip methods, and I have run into errors with each. I got the CPU install working via conda, but with the amount of data I need to process, it is too slow to get the job done.

I am working on Microsoft Windows 10, version 22H2 (OS Build 19045.3086). Below are the steps I have taken for each route and the terminal/kernel output:


conda method (running as administrator because sometimes I don't have the proper permissions; working on my lab computer)

from terminal:
1) rmdir /S keypoint-moseq (I just make sure I'm starting from a clean slate)
2) git clone https://github.com/dattalab/keypoint-moseq
3) chdir keypoint-moseq
4) conda env remove --name keypoint_moseq (again, clean slate)
5) conda env create -f conda_envs\environment.win64_gpu.yml
6) conda activate keypoint_moseq
7) conda install -c conda-forge pytables
   This line has led to errors, so if pytables is being buggy I often use conda install -c conda-forge pytables=3.8.0 or conda install -c conda-forge pytables=3.7.0 instead.
8) python -m ipykernel install --user --name=keypoint_moseq
9) At this point, when I run nvcc --version, this is my output:
   (keypoint_moseq) C:\Windows\System32\keypoint-moseq>nvcc --version
   nvcc: NVIDIA (R) Cuda compiler driver
   Copyright (c) 2005-2023 NVIDIA Corporation
   Built on Mon_Apr__3_17:36:15_Pacific_Daylight_Time_2023
   Cuda compilation tools, release 12.1, V12.1.105
   Build cuda_12.1.r12.1/compiler.32688072_0

10) jupyter notebook

from jupyter notebook:
11) New > keypoint_moseq (notebook)
12) When I run "import keypoint_moseq as kpms", the kernel outputs: ImportError: Numba needs NumPy 1.24 or less

I usually go back and pip install a compatible NumPy version with pip install numpy==1.24.0. Since this leads to

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
seaborn 0.12.2 requires numpy!=1.24.0,>=1.17, but you have numpy 1.24.0 which is incompatible.

I then run pip install numpy==1.23.0, and this usually does the trick.
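The two constraints in play here come straight from the messages above: Numba's import check ("Numba needs NumPy 1.24 or less") and seaborn 0.12.2's pin (numpy!=1.24.0,>=1.17). A minimal sketch, standard library only, that checks a candidate NumPy version against both before installing it. The helper names are hypothetical, and the parser assumes plain X.Y.Z release numbers; it knows nothing about any other pins in the environment.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple


def release_tuple(v: str) -> Tuple[int, int, int]:
    """Turn '1.24' or '1.23.5' into a zero-padded (major, minor, patch) tuple."""
    nums = []
    for part in v.split(".")[:3]:
        digits = ""
        for ch in part:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        nums.append(int(digits))
    while len(nums) < 3:
        nums.append(0)
    return (nums[0], nums[1], nums[2])


def numpy_satisfies_constraints(v: str) -> bool:
    """Check the Numba (1.24.x or older) and seaborn 0.12.2 (!=1.24.0, >=1.17) constraints."""
    rel = release_tuple(v)
    numba_ok = rel[:2] <= (1, 24)  # "Numba needs NumPy 1.24 or less"
    seaborn_ok = rel != (1, 24, 0) and rel >= (1, 17, 0)  # numpy!=1.24.0,>=1.17
    return numba_ok and seaborn_ok


def installed_numpy() -> Optional[str]:
    """Return the installed NumPy version string, or None if NumPy is not installed."""
    try:
        return version("numpy")
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    for candidate in ["1.25.0", "1.24.0", "1.23.0"]:
        print(candidate, numpy_satisfies_constraints(candidate))
    print("installed:", installed_numpy())
```

Going only by these two messages, a 1.24.x patch release newer than 1.24.0 would also pass the check.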

13) I launch jupyter notebook again and run "import keypoint_moseq as kpms". This time the line succeeds, so I continue as follows:

-line 1

import keypoint_moseq as kpms

-line 2

project_dir = r"C:\Users\nico\Desktop\keypoint_moseq\micahsfc_test"
config = lambda: kpms.load_config(project_dir)

-line 3

dlc_config = r"C:\Users\nico\Desktop\keypoint_moseq\micahsfc_test\config.yaml"
kpms.setup_project(project_dir, deeplabcut_config=dlc_config, overwrite=True)
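A note on the Windows paths in these cells. In an ordinary Python string literal, "C:\Users..." contains the escape \U, which Python tries to read as a Unicode escape and rejects with a SyntaxError, and a sequence such as \n (as in \nico) would silently turn into a newline. Here is a minimal sketch that builds the same project paths with raw strings and the standard library's pathlib.PureWindowsPath. The directory names come from this report. PureWindowsPath is used so the sketch behaves identically on any OS; on Windows itself, pathlib.Path works the same way.

```python
from pathlib import PureWindowsPath

# Raw strings (r"...") keep backslashes literal, so "\U" and "\n" are not treated as escapes.
project_dir = PureWindowsPath(r"C:\Users\nico\Desktop\keypoint_moseq\micahsfc_test")

# The "/" operator joins path components without typing separators by hand.
dlc_config = project_dir / "config.yaml"
video_dir = project_dir / "videos"

print(str(dlc_config))
print(str(video_dir))
```

If a kpms function expects a plain string rather than a path object, passing str(dlc_config) keeps the call unchanged.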

-line 4

kpms.update_config(
    project_dir,
    video_dir=r"C:\Users\nico\Desktop\keypoint_moseq\micahsfc_test\videos",
    anterior_bodyparts=['nose'],
    posterior_bodyparts=['lumbar spine'],
    use_bodyparts=['nose', 'head', 'leftear', 'rightear', 'cervical spine', 'thoracic spine', 'lumbar spine', 'tailbase'])

-line 5

# load data (e.g. from DeepLabCut)
keypoint_data_path = r"C:\Users\nico\Desktop\keypoint_moseq\micahsfc_test\videos"  # can be a file, a directory, or a list of files
coordinates, confidences, bodyparts = kpms.load_keypoints(keypoint_data_path, 'deeplabcut')

# format data for modeling
data, labels = kpms.format_data(coordinates, confidences=confidences, **config())

-line 6 (this cell works and I get a proper PCA)

pca = kpms.fit_pca(**data, **config())
kpms.save_pca(pca, project_dir)

kpms.print_dims_to_explain_variance(pca, 0.90)
kpms.plot_scree(pca, project_dir=project_dir)
kpms.plot_pcs(pca, project_dir=project_dir, **config())

# use the following to load an already
# pca = kpms.load_pca(project_dir)

-line 7

kpms.update_config(project_dir, latent_dim=4)

-line 8

# optionally update kappa in the config before initializing
kpms.update_config(project_dir=project_dir, kappa=1.25e5)  # made kappa 1.25e5 for the square videos!! this worked for the AR-HMM fitting.

# initialize the model
model = kpms.init_model(data, pca=pca, **config())

Upon running this final step (line 8), I get the following error, which I have not yet found a solution to:

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax_moseq\models\arhmm\initialize.py:45, in init_ar_params(seed, num_states, nu_0, S_0, M_0, K_0, **kwargs)
     43 seeds = jr.split(seed, num_states)
     44 in_axes = (0, na, na, na, na)
---> 45 Ab, Q = jax.vmap(sample_mniw, in_axes)(
     46     seeds, nu_0, S_0, M_0, K_0)
     47 return Ab, Q

[... skipping hidden 3 frame]

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax_moseq\utils\debugging.py:249, in check_output.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
    246 @functools.wraps(func)
    247 def wrapper(*args, **kwargs):
    248     try:
--> 249         result = func(*args, **kwargs)
    250         if not hasattr(sys, '_checked_function_args') or not sys._checked_function_args.active:
    251             return result

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax_moseq\utils\distributions.py:49, in sample_mniw(seed, nu, S, M, K)
     47 @nan_check
     48 def sample_mniw(seed, nu, S, M, K):
---> 49     sigma = sample_invwishart(seed, S, nu)
     50     A = sample_mn(seed, M, sigma, K)
     51     return A, sigma

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax_moseq\utils\debugging.py:249, in check_output.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
    246 @functools.wraps(func)
    247 def wrapper(*args, **kwargs):
    248     try:
--> 249         result = func(*args, **kwargs)
    250         if not hasattr(sys, '_checked_function_args') or not sys._checked_function_args.active:
    251             return result

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jax_moseq\utils\distributions.py:40, in sample_invwishart(seed, S, nu)
     38 x = jnp.diag(jnp.sqrt(sample_chi2(chi2_seed, nu - jnp.arange(n))))
     39 x = x.at[jnp.triu_indices_from(x,1)].set(jr.normal(norm_seed, (n*(n-1)//2,)))
---> 40 R = jnp.linalg.qr(x,'r')
     42 chol = jnp.linalg.cholesky(S)
     44 T = jax.scipy.linalg.solve_triangular(R.T,chol.T,lower=True).T

[... skipping hidden 19 frame]

File ~\anaconda3\envs\keypoint_moseq\lib\site-packages\jaxlib\gpu_solver.py:308, in _orgqr_mhlo(platform, gpu_solver, dtype, a, tau)
    305 assert tau_dims[:-1] == dims[:-2]
    306 k = tau_dims[-1]
--> 308 lwork, opaque = gpu_solver.build_orgqr_descriptor(
    309     np.dtype(dtype), batch, m, n, k)
    311 layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
    312 i32_type = ir.IntegerType.get_signless(32)

RuntimeError: jaxlib/cuda/cusolver_kernels.cc:45: operation cusolverDnCreate(&handle) failed: cuSolver internal error

If I enter python -c "import jax; print(jax.devices())" in the terminal, the output is: [StreamExecutorGpuDevice(id=0, process_index=0)]
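Since jax.devices() does list a GPU, the failure looks specific to the solver call. The traceback ends in jnp.linalg.qr inside sample_invwishart, at the point where jaxlib fails to create a cuSolver handle (cusolverDnCreate). Here is a minimal sketch for reproducing that one step without keypoint-moseq. It assumes only that JAX is installed, and check_gpu_qr is a hypothetical helper name. If it fails with the same RuntimeError, the problem most likely lies in the JAX/CUDA installation rather than in the modeling code.

```python
def check_gpu_qr(n: int = 4) -> str:
    """Run a small QR decomposition, the operation that fails in sample_invwishart."""
    try:
        import jax
        import jax.numpy as jnp
    except ImportError as exc:
        return f"jax could not be imported: {exc}"

    backend = jax.default_backend()  # typically "gpu" when a GPU backend is active, else "cpu"
    x = jnp.eye(n) + 0.1 * jnp.arange(n * n, dtype=jnp.float32).reshape(n, n)
    try:
        # mode="r" returns only the triangular factor, like jnp.linalg.qr(x, 'r') in the traceback
        r = jnp.linalg.qr(x, mode="r")
        r.block_until_ready()  # force execution so any GPU error surfaces here
    except RuntimeError as exc:
        return f"QR failed on backend {backend!r}: {exc}"
    return f"QR succeeded on backend {backend!r}"


if __name__ == "__main__":
    print(check_gpu_qr())
```

Running this as a script from the activated environment, rather than in the notebook, keeps the full error text in the terminal.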


pip method (again running as administrator, but I have tried both with and without administrator rights)
1) Download CUDA version 11.1.0
2) Run nvcc --version to verify CUDA version 11.1.0:
   (base) C:\Windows\system32>nvcc --version
   nvcc: NVIDIA (R) Cuda compiler driver
   Copyright (c) 2005-2020 NVIDIA Corporation
   Built on Tue_Sep_15_19:12:04_Pacific_Daylight_Time_2020
   Cuda compilation tools, release 11.1, V11.1.74
   Build cuda_11.1.relgpu_drvr455TC455_06.29069683_0
3) rmdir /S keypoint-moseq and conda env remove --name keypoint_moseq (clean slate), then:
   conda create -n keypoint_moseq python=3.9
   conda activate keypoint_moseq
4) Windows (GPU): pip install jax==0.3.22 https://whls.blob.core.windows.net/unstable/cuda111/jaxlib-0.3.22+cuda11.cudnn82-cp39-cp39-win_amd64.whl
5) pip install keypoint-moseq
6) conda install -c conda-forge pytables=3.8.0 (I have tried the other versions of this command)
7) python -m ipykernel install --user --name=keypoint_moseq
8) jupyter notebook

from jupyter notebook:
9) When I run "import keypoint_moseq as kpms", the kernel immediately dies. I checked the troubleshooting section, so I tried running import jax; jax.config.update('jax_enable_x64', False) in the same cell as "import keypoint_moseq as kpms", but this did not work. Below is my output from running nvidia-smi:

[screenshot: nvidia-smi output]

I am unsure why nvidia-smi and nvcc --version report different CUDA versions, but this would not explain why the conda method failed.

When I run python -c "import jax; print(jax.devices())", the terminal outputs: [StreamExecutorGpuDevice(id=0, process_index=0)]

10) After the kernel fails, I tried running "conda install -c nvidia cuda-nvcc" in the terminal. This actually lets the notebook code run all the way to the same point as in the conda method (line 8), where I hit the same cuSolver error.


I suspect the pip errors may be due to my cuDNN version, which I am having difficulty finding. I did my best to install cuDNN on my computer previously, updating the path entries after extracting the cuDNN package. Regardless, the conda method should still work based on the steps I outlined above. Thank you.
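One way to pin down which cuDNN the pip route expects: the jaxlib wheel from step 4 encodes its build in the filename. Under the standard wheel naming convention, jaxlib-0.3.22+cuda11.cudnn82-cp39-cp39-win_amd64.whl consists of the name, the version (with local tag cuda11.cudnn82), the Python tag, the ABI tag, and the platform. Below is a minimal sketch that decodes it and compares the Python tag with the running interpreter. It assumes the cudaNN / cudnnMN local-tag layout seen in this one filename, with a single-digit cuDNN major version; parse_wheel_name is a hypothetical helper, not a jaxlib API.

```python
import re
import sys
from typing import Dict, Optional, Tuple


def parse_wheel_name(filename: str) -> Dict[str, object]:
    """Split name-version(-build)-python-abi-platform.whl and decode a cudaNN.cudnnMN local tag."""
    stem = filename[: -len(".whl")] if filename.endswith(".whl") else filename
    parts = stem.split("-")
    name, ver = parts[0], parts[1]
    py_tag, abi_tag, platform = parts[-3], parts[-2], parts[-1]
    release, _, local = ver.partition("+")  # "0.3.22", "cuda11.cudnn82"

    cuda_major: Optional[int] = None
    cudnn: Optional[Tuple[int, int]] = None
    m = re.search(r"cuda(\d+)", local)
    if m:
        cuda_major = int(m.group(1))
    m = re.search(r"cudnn(\d)(\d+)", local)  # assumes a single-digit major, e.g. "82" -> 8.2
    if m:
        cudnn = (int(m.group(1)), int(m.group(2)))
    return {"name": name, "release": release, "python": py_tag, "abi": abi_tag,
            "platform": platform, "cuda_major": cuda_major, "cudnn": cudnn}


def python_tag_matches(py_tag: str) -> bool:
    """True if a cpXY tag matches the interpreter running this code."""
    return py_tag == f"cp{sys.version_info.major}{sys.version_info.minor}"


if __name__ == "__main__":
    info = parse_wheel_name("jaxlib-0.3.22+cuda11.cudnn82-cp39-cp39-win_amd64.whl")
    print(info)
    print("matches this Python:", python_tag_matches(str(info["python"])))
```

Read this way, the wheel targets CUDA 11, cuDNN 8.2, and CPython 3.9, which matches the python=3.9 environment created in step 3.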

calebweinreb commented 1 year ago

Hey! A couple thoughts:

  1. For the conda install, maybe you're encountering a memory allocation issue? (see this issue).

    • Can you confirm that the VRAM is completely free (0MiB/ 12882MiB) right before you start going through the notebook?
    • Can you try running %env XLA_PYTHON_CLIENT_PREALLOCATE=false at the top of the notebook? (as suggested here)
  2. For the pip install, "nvcc" and "nvidia-smi" show different versions because the latter is showing the graphics driver version (its "CUDA Version" field is the newest CUDA version the driver supports, not the installed toolkit). In any case, maybe we can first confirm that you have the correct cuDNN version (8.2). There are a bunch of Stack Exchange threads on how to do that. For example, you could try the following (modifying the path as needed):

    type "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.1\include\cudnn.h" | findstr "CUDNN_MAJOR CUDNN_MINOR CUDNN_PATCHLEVEL"
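For scripting the toolkit side of this check, here is a minimal sketch using only the standard library. It pulls the release number out of nvcc --version output and compares it with the CUDA 11.x build that the cuda111 jaxlib wheel targets. The helper names are hypothetical. The sample lines in the test are the nvcc output quoted earlier in this issue. The live check assumes nvcc is on the PATH of the process that runs it.

```python
import re
import shutil
import subprocess
from typing import Optional, Tuple


def cuda_release(nvcc_output: str) -> Optional[Tuple[int, int]]:
    """Extract (major, minor) from a line like 'Cuda compilation tools, release 11.1, V11.1.74'."""
    m = re.search(r"release (\d+)\.(\d+)", nvcc_output)
    return (int(m.group(1)), int(m.group(2))) if m else None


def nvcc_version_text() -> Optional[str]:
    """Run `nvcc --version` if nvcc is on PATH and return its stdout, else None."""
    exe = shutil.which("nvcc")
    if exe is None:
        return None
    return subprocess.run([exe, "--version"], capture_output=True, text=True).stdout


if __name__ == "__main__":
    text = nvcc_version_text()
    rel = cuda_release(text) if text else None
    print("nvcc release:", rel)
    print("matches a CUDA 11.x wheel:", rel is not None and rel[0] == 11)
```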
micahbaldonado commented 1 year ago

Conda

For the conda method, I tried running %env XLA_PYTHON_CLIENT_PREALLOCATE=false at the top of the notebook, and I still got the cuSolver error. Previously, I tried doing this with "set XLA_PYTHON_CLIENT_PREALLOCATE = false" in the terminal.
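Two details may matter here. First, XLA_PYTHON_CLIENT_PREALLOCATE has to be in the environment before JAX initializes its GPU backend, so the safe place to set it is before jax is first imported in the kernel (keypoint_moseq imports jax). Second, in Windows cmd, set NAME = value with spaces around = defines a variable whose name ends in a space and whose value starts with one, so that form would likely not take effect; set XLA_PYTHON_CLIENT_PREALLOCATE=false (no spaces) is the usual form. Below is a minimal Python sketch that sets the variable from inside the process and warns if JAX was already imported. disable_preallocation is a hypothetical helper name.

```python
import os
import sys
import warnings


def disable_preallocation() -> None:
    """Ask XLA not to reserve most of the GPU memory up front."""
    if "jax" in sys.modules:
        # The GPU backend may already be initialized, in which case the setting is ignored.
        warnings.warn("jax was imported before XLA_PYTHON_CLIENT_PREALLOCATE was set; "
                      "restart the kernel and run this first")
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"


disable_preallocation()
# import keypoint_moseq as kpms  # import only after the variable is set
```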

[screenshot] This is from before running anything in the terminal.

[screenshot] This is from right before I start going through the notebook, but after I enter the terminal commands.

After running the commands in the Jupyter notebook, the Performance tab of my Task Manager looks like this:

[screenshot: Task Manager Performance tab]

The utilization rarely went above 6% as I ran the required cells in the notebook. Once, the readings were briefly at 14% and 100% respectively, but this did not affect anything. I ran all the cells a second time and there was no 100% usage, so I do not believe this is an allocation problem. Instead, I think it is a problem in the interaction between JAX and Windows.


Pip

With %env XLA_PYTHON_CLIENT_PREALLOCATE=false, the pip install still crashed. This surprised me, because I had assumed the crash was caused by GPU memory allocation. However, when I ran the command, the "Shared GPU memory" never exceeded 0.8/31.9 GB.

[screenshot] This is what I get when I try to find my cuDNN version. I find this bizarre, because there is a file called cudnn.h at exactly the path used in this command: type "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.1\include\cudnn.h" | findstr "CUDNN_MAJOR CUDNN_MINOR CUDNN_PATCHLEVEL".

I retried my cuDNN installation, this time following the method in https://forums.developer.nvidia.com/t/installing-cudnn/229612/9


With these changes to my installation route, I still get no output from type "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.1\include\cudnn.h" | findstr "CUDNN_MAJOR CUDNN_MINOR CUDNN_PATCHLEVEL"

I believe the cuDNN installation is especially tricky because of this step in the NVIDIA instructions:

[screenshot: NVIDIA cuDNN installation instructions]

However, the path C:\Program Files\NVIDIA\CUDNN\v8.x\bin does not exist. Specifically, there is no CUDNN folder, so you have to create it yourself, which is what the solution above tried with the C:\Program Files\NVIDIA GPU Computing Toolkit\CUDNN\v12.x\bin path. Thank you.

micahbaldonado commented 1 year ago

As you suggested in our discussion yesterday, I uninstalled all CUDA-related software that wasn't version 11.1 and added the cuDNN 8.6.0 (compatible with CUDA 11.x) files and path entries to my system.

While the line type "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.1\include\cudnn.h" | findstr "CUDNN_MAJOR CUDNN_MINOR CUDNN_PATCHLEVEL" still does not output anything in the terminal, I am fairly sure I installed cuDNN and set the necessary path variables correctly.
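One likely reason findstr prints nothing: starting with cuDNN 8, the CUDNN_MAJOR, CUDNN_MINOR and CUDNN_PATCHLEVEL defines live in cudnn_version.h, and cudnn.h only includes that file. Running the same type ... | findstr command against cudnn_version.h in the same include folder should show them. Below is a minimal sketch that looks for the defines in either header. The default directory is the one used in this thread and may need adjusting; the function names are hypothetical.

```python
import re
from pathlib import Path
from typing import Optional, Tuple

DEFAULT_INCLUDE = r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.1\include"


def parse_cudnn_version(header_text: str) -> Optional[Tuple[int, int, int]]:
    """Read '#define CUDNN_MAJOR 8' style lines; return None if any of the three is missing."""
    found = {}
    for key in ("CUDNN_MAJOR", "CUDNN_MINOR", "CUDNN_PATCHLEVEL"):
        m = re.search(rf"#define\s+{key}\s+(\d+)", header_text)
        if m is None:
            return None
        found[key] = int(m.group(1))
    return (found["CUDNN_MAJOR"], found["CUDNN_MINOR"], found["CUDNN_PATCHLEVEL"])


def find_cudnn_version(include_dir: str = DEFAULT_INCLUDE) -> Optional[Tuple[int, int, int]]:
    """Check cudnn_version.h first (cuDNN 8 and later), then cudnn.h (older releases)."""
    for name in ("cudnn_version.h", "cudnn.h"):
        header = Path(include_dir) / name
        if header.is_file():
            found_version = parse_cudnn_version(header.read_text(errors="ignore"))
            if found_version is not None:
                return found_version
    return None


if __name__ == "__main__":
    print(find_cudnn_version())
```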

To address the pip issue, I went through all the dependencies in the setup.cfg file, importing each to see which one caused the line "import keypoint_moseq as kpms" to kill the kernel.


The kernel dies when importing jax_moseq with CUDA 11.1 and cuDNN 8.6.0. As mentioned earlier, updating CUDA resolves this error but leads to the cuSolver error later on. Thanks.

calebweinreb commented 1 year ago

What about just import jax? Or the jax_moseq deps?

numba>=0.56.4
dynamax
chex==0.1.6
micahbaldonado commented 1 year ago


import jax works fine, and so do all the jax_moseq dependencies. I don't think you can import a package with a version specification in Python anyway, since the versions would already have been pinned by pip install keypoint-moseq.
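On the version question: import takes only the module name, and the pins are checked separately. Below is a minimal sketch, standard library only, that strips the specifier from each requirement in Caleb's list, imports the module, and reports the installed distribution version via importlib.metadata. The helper names are hypothetical. The sketch assumes the import name equals the distribution name, which holds for numba, dynamax and chex but not for every package. A hard crash of the interpreter cannot be caught by the except clause.

```python
import importlib
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Iterable, Optional


def requirement_name(req: str) -> str:
    """'numba>=0.56.4' -> 'numba', 'chex==0.1.6' -> 'chex', 'dynamax' -> 'dynamax'."""
    return re.split(r"[\s<>=!~;\[]", req.strip(), maxsplit=1)[0]


def installed_version(dist: str) -> Optional[str]:
    """Installed version of a distribution, or None if it is not installed."""
    try:
        return version(dist)
    except PackageNotFoundError:
        return None


def check_requirements(reqs: Iterable[str]) -> None:
    for req in reqs:
        name = requirement_name(req)
        try:
            importlib.import_module(name)
            status = "imports OK"
        except Exception as exc:  # report any ordinary import failure instead of stopping
            status = f"import failed: {exc!r}"
        print(f"{req:<16} installed={installed_version(name)}  {status}")


if __name__ == "__main__":
    check_requirements(["numba>=0.56.4", "dynamax", "chex==0.1.6"])
```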

calebweinreb commented 1 year ago

Interesting. I guess the next step would be to start importing the individual components of jax_moseq. Or maybe try importing in a python terminal (rather than a notebook) and see if a specific error comes up? Or maybe even with a notebook there's some error printed in the terminal from which you launched the notebook?
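The suggestion to import outside the notebook can be scripted. Below is a minimal sketch, standard library only, that imports each module in a fresh interpreter, so a crash cannot take down the caller. It enables Python's built-in fault handler (-X faulthandler) so that a hard crash still prints a traceback. try_import is a hypothetical helper, and the module list uses names from the traceback earlier in this issue. Importing a submodule also runs its package's __init__, so this narrows down where a crash happens rather than fully isolating each file.

```python
import subprocess
import sys
from typing import Tuple


def try_import(module: str, timeout: float = 600.0) -> Tuple[int, str]:
    """Import `module` in a child interpreter; return (exit code, captured stderr)."""
    proc = subprocess.run(
        [sys.executable, "-X", "faulthandler", "-c", f"import {module}"],
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    return proc.returncode, proc.stderr


if __name__ == "__main__":
    modules = [
        "jax",
        "jax_moseq",
        "jax_moseq.utils.distributions",
        "jax_moseq.models.arhmm.initialize",
    ]
    for name in modules:
        code, err = try_import(name)
        print(f"{name}: exit code {code}")
        if code != 0:
            print(err.strip() or "(nothing written to stderr)")
```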

micahbaldonado commented 1 year ago

[screenshot]

This is the error that comes up in the terminal, and it's the same one that shows up when you run "import jax_moseq" in the notebook:

[screenshot]

I looked at the following case, which seems fairly similar: https://github.com/deepmind/alphafold/issues/122

I tried adding the ptxas bin directory and file path to the "Path" variable under "Environment Variables", but I still got the same error. I'm a bit in the dark on how I would import individual components of jax_moseq beyond the dependencies listed in the config file.
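Changes made in the Environment Variables dialog only reach processes started afterwards. The terminal that launches Jupyter (and therefore the kernel) has to be reopened before a new Path entry becomes visible. Below is a minimal sketch, standard library only, for checking what the running Python process actually sees. The helper names are hypothetical.

```python
import os
import shutil
from typing import List, Optional


def resolve_tool(name: str = "ptxas") -> Optional[str]:
    """Full path of `name` as this process would find it on PATH, or None."""
    return shutil.which(name)


def path_entries_matching(fragment: str) -> List[str]:
    """Non-empty PATH entries of this process that contain `fragment` (case-insensitive)."""
    entries = os.environ.get("PATH", "").split(os.pathsep)
    return [e for e in entries if e and fragment.lower() in e.lower()]


if __name__ == "__main__":
    print("ptxas:", resolve_tool("ptxas"))
    for entry in path_entries_matching("cuda"):
        print("PATH entry:", entry)
```

Running this in a notebook cell and in a freshly opened terminal shows whether the two environments see the same entries.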

calebweinreb commented 1 year ago

Ahh, this sounds very frustrating. I'm sorry you're having so many issues! The only thing I can think of at this point is that your operating system (Windows 12) is too new for any of the available jax installs. So one option would be to downgrade your OS to Windows 11 or 10. Another option would be to just use Google Colab. I think running Linux with WSL might not work, since it relies on the global CUDA install (see https://github.com/dattalab/keypoint-moseq/issues/39#issuecomment-1540707376).

micahbaldonado commented 1 year ago

I'm actually using Windows 10, but I might go the Linux route. Thank you so much for the help in all this. Really appreciate it.

JohannaMankel commented 1 year ago

I'm running into these exact same problems on Win11 and was also not able to resolve them.

calebweinreb commented 1 year ago

I'm sorry to hear that! I'm actually not sure if it's possible to install GPU-enabled jax on Windows 11. However, if you would like to try, I'm happy to help debug. If so, please open a separate issue and include details such as your CUDA and cuDNN versions (if you have installed them), how you installed kpms, and the exact error you got.