dattalab / keypoint-moseq

https://keypoint-moseq.readthedocs.io

cuSolver internal error #103

Closed wula2048 closed 9 months ago

wula2048 commented 9 months ago

Hi all,

I am trying to use keypoint-moseq. After setting up the virtual environment with the conda method, I tried to validate the installation using the demo data you provided, but I get the following error when running model = kpms.init_model(data, pca=pca, **config()):

JaxStackTraceBeforeTransformation         Traceback (most recent call last)
File [d:\anaconda\envs\keypoint_moseq\lib\runpy.py:197](file:///D:/anaconda/envs/keypoint_moseq/lib/runpy.py:197), in _run_module_as_main(***failed resolving arguments***)
    196     sys.argv[0] = mod_spec.origin
--> 197 return _run_code(code, main_globals, None,
    198                  "__main__", mod_spec)

File [d:\anaconda\envs\keypoint_moseq\lib\runpy.py:87](file:///D:/anaconda/envs/keypoint_moseq/lib/runpy.py:87), in _run_code(***failed resolving arguments***)
     80 run_globals.update(__name__ = mod_name,
     81                    __file__ = fname,
     82                    __cached__ = cached,
   (...)
     85                    __package__ = pkg_name,
     86                    __spec__ = mod_spec)
---> 87 exec(code, run_globals)
     88 return run_globals

File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\ipykernel_launcher.py:17](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/ipykernel_launcher.py:17)
     15 from ipykernel import kernelapp as app
---> 17 app.launch_new_instance()

File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\traitlets\config\application.py:1043](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/traitlets/config/application.py:1043), in Application.launch_instance(***failed resolving arguments***)
   1042 app.initialize(argv)
-> 1043 app.start()

File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\ipykernel\kernelapp.py:736](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/ipykernel/kernelapp.py:736), in IPKernelApp.start(***failed resolving arguments***)
    735 try:
--> 736     self.io_loop.start()
    737 except KeyboardInterrupt:

File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\tornado\platform\asyncio.py:195](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/tornado/platform/asyncio.py:195), in BaseAsyncIOLoop.start(***failed resolving arguments***)
    194 def start(self) -> None:
--> 195     self.asyncio_loop.run_forever()

File [d:\anaconda\envs\keypoint_moseq\lib\asyncio\base_events.py:601](file:///D:/anaconda/envs/keypoint_moseq/lib/asyncio/base_events.py:601), in BaseEventLoop.run_forever(***failed resolving arguments***)
    600 while True:
--> 601     self._run_once()
    602     if self._stopping:

File [d:\anaconda\envs\keypoint_moseq\lib\asyncio\base_events.py:1905](file:///D:/anaconda/envs/keypoint_moseq/lib/asyncio/base_events.py:1905), in BaseEventLoop._run_once(***failed resolving arguments***)
   1904     else:
-> 1905         handle._run()
   1906 handle = None

File [d:\anaconda\envs\keypoint_moseq\lib\asyncio\events.py:80](file:///D:/anaconda/envs/keypoint_moseq/lib/asyncio/events.py:80), in Handle._run(***failed resolving arguments***)
     79 try:
---> 80     self._context.run(self._callback, *self._args)
     81 except (SystemExit, KeyboardInterrupt):

File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\ipykernel\kernelbase.py:516](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/ipykernel/kernelbase.py:516), in Kernel.dispatch_queue(***failed resolving arguments***)
    515 try:
--> 516     await self.process_one()
    517 except Exception:

File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\ipykernel\kernelbase.py:505](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/ipykernel/kernelbase.py:505), in Kernel.process_one(***failed resolving arguments***)
    504         return None
--> 505 await dispatch(*args)

File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\ipykernel\kernelbase.py:412](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/ipykernel/kernelbase.py:412), in Kernel.dispatch_shell(***failed resolving arguments***)
    411     if inspect.isawaitable(result):
--> 412         await result
    413 except Exception:

File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\ipykernel\kernelbase.py:740](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/ipykernel/kernelbase.py:740), in Kernel.execute_request(***failed resolving arguments***)
    739 if inspect.isawaitable(reply_content):
--> 740     reply_content = await reply_content
    742 # Flush output before sending the reply.

File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\ipykernel\ipkernel.py:422](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/ipykernel/ipkernel.py:422), in IPythonKernel.do_execute(***failed resolving arguments***)
    421 if with_cell_id:
--> 422     res = shell.run_cell(
    423         code,
    424         store_history=store_history,
    425         silent=silent,
    426         cell_id=cell_id,
    427     )
    428 else:

File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\ipykernel\zmqshell.py:546](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/ipykernel/zmqshell.py:546), in ZMQInteractiveShell.run_cell(***failed resolving arguments***)
    545 self._last_traceback = None
--> 546 return super().run_cell(*args, **kwargs)

File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\IPython\core\interactiveshell.py:3024](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/IPython/core/interactiveshell.py:3024), in InteractiveShell.run_cell(***failed resolving arguments***)
   3023 try:
-> 3024     result = self._run_cell(
   3025         raw_cell, store_history, silent, shell_futures, cell_id
   3026     )
   3027 finally:

File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\IPython\core\interactiveshell.py:3079](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/IPython/core/interactiveshell.py:3079), in InteractiveShell._run_cell(***failed resolving arguments***)
   3078 try:
-> 3079     result = runner(coro)
   3080 except BaseException as e:

File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\IPython\core\async_helpers.py:129](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/IPython/core/async_helpers.py:129), in _pseudo_sync_runner(***failed resolving arguments***)
    128 try:
--> 129     coro.send(None)
    130 except StopIteration as exc:

File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\IPython\core\interactiveshell.py:3284](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/IPython/core/interactiveshell.py:3284), in InteractiveShell.run_cell_async(***failed resolving arguments***)
   3281 interactivity = "none" if silent else self.ast_node_interactivity
-> 3284 has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
   3285        interactivity=interactivity, compiler=compiler, result=result)
   3287 self.last_execution_succeeded = not has_raised

File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\IPython\core\interactiveshell.py:3466](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/IPython/core/interactiveshell.py:3466), in InteractiveShell.run_ast_nodes(***failed resolving arguments***)
   3465     asy = compare(code)
-> 3466 if await self.run_code(code, result, async_=asy):
   3467     return True

File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\IPython\core\interactiveshell.py:3526](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/IPython/core/interactiveshell.py:3526), in InteractiveShell.run_code(***failed resolving arguments***)
   3525     else:
-> 3526         exec(code_obj, self.user_global_ns, self.user_ns)
   3527 finally:
   3528     # Reset our crash handler in place

[f:\test\keypoint_moseq\demo_dlc_test.ipynb](file:///F:/test/keypoint_moseq/demo_dlc_test.ipynb) Cell 8 line 2
      [1](vscode-notebook-cell:/f%3A/test/keypoint_moseq/demo_dlc_test.ipynb#X12sZmlsZQ%3D%3D?line=0) # initialize the model
----> [2](vscode-notebook-cell:/f%3A/test/keypoint_moseq/demo_dlc_test.ipynb#X12sZmlsZQ%3D%3D?line=1) model = kpms.init_model(data, pca=pca, **config())

File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\jax_moseq\models\keypoint_slds\initialize.py:328](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/jax_moseq/models/keypoint_slds/initialize.py:328), in init_model(***failed resolving arguments***)
    324         pca = utils.fit_pca(
    325             Y_flat, pca_mask, PCA_fitting_num_frames, verbose
    326         )
--> 328     params = init_params(
    329         seed, pca, Y_flat, mask, **hypparams, whiten=whiten, k=Y.shape[-2]
    330     )
    332 else:

File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\jax_moseq\models\keypoint_slds\initialize.py:110](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/jax_moseq/models/keypoint_slds/initialize.py:110), in init_params(***failed resolving arguments***)
     80 """
     81 Initialize the parameters of the keypoint SLDS from the
     82 data and hyperparameters.
   (...)
    108     Values for each model parameter.
    109 """
--> 110 params = arhmm.init_params(seed, trans_hypparams, ar_hypparams)
    111 params["Cd"] = slds.init_obs_params(
    112     pca, Y_flat, mask, whiten, **ar_hypparams
    113 )

File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\jax_moseq\models\arhmm\initialize.py:101](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/jax_moseq/models/arhmm/initialize.py:101), in init_params(***failed resolving arguments***)
     98 params["betas"], params["pi"] = init_hdp_transitions(
     99     seed, **trans_hypparams
    100 )
--> 101 params["Ab"], params["Q"] = init_ar_params(seed, **ar_hypparams)
    102 return params

File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\jax_moseq\models\arhmm\initialize.py:45](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/jax_moseq/models/arhmm/initialize.py:45), in init_ar_params(***failed resolving arguments***)
     44 in_axes = (0, na, na, na, na)
---> 45 Ab, Q = jax.vmap(sample_mniw, in_axes)(seeds, nu_0, S_0, M_0, K_0)
     46 return Ab, Q

File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\jax_moseq\utils\distributions.py:59](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/jax_moseq/utils/distributions.py:59), in sample_mniw(***failed resolving arguments***)
     58 def sample_mniw(seed, nu, S, M, K):
---> 59     sigma = sample_invwishart(seed, S, nu)
     60     A = sample_mn(seed, M, sigma, K)

File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\jax_moseq\utils\distributions.py:50](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/jax_moseq/utils/distributions.py:50), in sample_invwishart(***failed resolving arguments***)
     47 x = x.at[jnp.triu_indices_from(x, 1)].set(
     48     jr.normal(norm_seed, (n * (n - 1) // 2,))
     49 )
---> 50 R = jnp.linalg.qr(x, "r")
     52 chol = jnp.linalg.cholesky(S)

File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\jax\_src\numpy\linalg.py:570](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/jax/_src/numpy/linalg.py:570), in qr(***failed resolving arguments***)
    569   raise ValueError(f"Unsupported QR decomposition mode '{mode}'")
--> 570 q, r = lax_linalg.qr(a, full_matrices=full_matrices)
    571 if mode == "r":

JaxStackTraceBeforeTransformation: RuntimeError: jaxlib/cuda/cusolver_kernels.cc:45: operation cusolverDnCreate(&handle) failed: cuSolver internal error

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
[f:\test\keypoint_moseq\demo_dlc_test.ipynb](file:///F:/test/keypoint_moseq/demo_dlc_test.ipynb) Cell 8 line 2
      [1](vscode-notebook-cell:/f%3A/test/keypoint_moseq/demo_dlc_test.ipynb#X12sZmlsZQ%3D%3D?line=0) # initialize the model
----> [2](vscode-notebook-cell:/f%3A/test/keypoint_moseq/demo_dlc_test.ipynb#X12sZmlsZQ%3D%3D?line=1) model = kpms.init_model(data, pca=pca, **config())

File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\jax_moseq\models\keypoint_slds\initialize.py:328](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/jax_moseq/models/keypoint_slds/initialize.py:328), in init_model(data, states, params, hypparams, noise_prior, seed, pca, whiten, PCA_fitting_num_frames, anterior_idxs, posterior_idxs, conf_threshold, error_estimator, trans_hypparams, ar_hypparams, obs_hypparams, cen_hypparams, verbose, exclude_outliers_for_pca, fix_heading, **kwargs)
    321             pca_mask = jnp.logical_and(
    322                 mask, (conf > conf_threshold).all(-1)
    323             )
    324         pca = utils.fit_pca(
    325             Y_flat, pca_mask, PCA_fitting_num_frames, verbose
    326         )
--> 328     params = init_params(
    329         seed, pca, Y_flat, mask, **hypparams, whiten=whiten, k=Y.shape[-2]
    330     )
    332 else:
    333     params = jax.device_put(params)

File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\jax_moseq\models\keypoint_slds\initialize.py:110](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/jax_moseq/models/keypoint_slds/initialize.py:110), in init_params(seed, pca, Y_flat, mask, trans_hypparams, ar_hypparams, whiten, k, **kwargs)
     77 def init_params(
     78     seed, pca, Y_flat, mask, trans_hypparams, ar_hypparams, whiten, k, **kwargs
     79 ):
     80     """
     81     Initialize the parameters of the keypoint SLDS from the
     82     data and hyperparameters.
   (...)
    108         Values for each model parameter.
    109     """
--> 110     params = arhmm.init_params(seed, trans_hypparams, ar_hypparams)
    111     params["Cd"] = slds.init_obs_params(
    112         pca, Y_flat, mask, whiten, **ar_hypparams
    113     )
    114     params["sigmasq"] = jnp.ones(k)

File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\jax_moseq\models\arhmm\initialize.py:101](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/jax_moseq/models/arhmm/initialize.py:101), in init_params(seed, trans_hypparams, ar_hypparams, **kwargs)
     97 params = {}
     98 params["betas"], params["pi"] = init_hdp_transitions(
     99     seed, **trans_hypparams
    100 )
--> 101 params["Ab"], params["Q"] = init_ar_params(seed, **ar_hypparams)
    102 return params

File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\jax_moseq\models\arhmm\initialize.py:45](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/jax_moseq/models/arhmm/initialize.py:45), in init_ar_params(seed, num_states, nu_0, S_0, M_0, K_0, **kwargs)
     43 seeds = jr.split(seed, num_states)
     44 in_axes = (0, na, na, na, na)
---> 45 Ab, Q = jax.vmap(sample_mniw, in_axes)(seeds, nu_0, S_0, M_0, K_0)
     46 return Ab, Q

    [... skipping hidden 3 frame]

File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\jax_moseq\utils\distributions.py:59](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/jax_moseq/utils/distributions.py:59), in sample_mniw(seed, nu, S, M, K)
     58 def sample_mniw(seed, nu, S, M, K):
---> 59     sigma = sample_invwishart(seed, S, nu)
     60     A = sample_mn(seed, M, sigma, K)
     61     return A, sigma

File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\jax_moseq\utils\distributions.py:50](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/jax_moseq/utils/distributions.py:50), in sample_invwishart(seed, S, nu)
     46 x = jnp.diag(jnp.sqrt(sample_chi2(chi2_seed, nu - jnp.arange(n))))
     47 x = x.at[jnp.triu_indices_from(x, 1)].set(
     48     jr.normal(norm_seed, (n * (n - 1) // 2,))
     49 )
---> 50 R = jnp.linalg.qr(x, "r")
     52 chol = jnp.linalg.cholesky(S)
     54 T = jax.scipy.linalg.solve_triangular(R.T, chol.T, lower=True).T

    [... skipping hidden 19 frame]

File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\jaxlib\gpu_solver.py:308](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/jaxlib/gpu_solver.py:308), in _orgqr_mhlo(platform, gpu_solver, dtype, a, tau)
    305 assert tau_dims[:-1] == dims[:-2]
    306 k = tau_dims[-1]
--> 308 lwork, opaque = gpu_solver.build_orgqr_descriptor(
    309     np.dtype(dtype), batch, m, n, k)
    311 layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
    312 i32_type = ir.IntegerType.get_signless(32)

RuntimeError: jaxlib/cuda/cusolver_kernels.cc:45: operation cusolverDnCreate(&handle) failed: cuSolver internal error

Additional Information

Operating System

keypoint-moseq version

CUDA and cuDNN version: [image]

nvcc --version output: [image]

nvidia-smi output: [image]

I noticed that a similar issue was reported previously (#69), but I did not find an effective solution there. I look forward to your help in resolving this. Thank you!

calebweinreb commented 9 months ago

I think the problem is that you are using CUDA 12 but the Windows builds of JAX only support CUDA 11 currently.
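One quick way to see which CUDA toolkit is first on the PATH is to parse the output of `nvcc --version`. The following is a minimal sketch, not part of keypoint-moseq or JAX; the helper names `parse_nvcc_release` and `system_cuda_release` are hypothetical. It only reports the system-wide toolkit, so it says nothing about a CUDA copy bundled inside a conda environment.

```python
import re
import shutil
import subprocess
from typing import Optional


def parse_nvcc_release(text: str) -> Optional[str]:
    """Extract the 'release X.Y' version from `nvcc --version` output."""
    match = re.search(r"release\s+(\d+)\.(\d+)", text)
    return f"{match.group(1)}.{match.group(2)}" if match else None


def system_cuda_release() -> Optional[str]:
    """Return the CUDA release reported by the nvcc found on PATH, if any."""
    nvcc = shutil.which("nvcc")
    if nvcc is None:
        return None
    result = subprocess.run([nvcc, "--version"], capture_output=True, text=True)
    return parse_nvcc_release(result.stdout)


if __name__ == "__main__":
    release = system_cuda_release()
    if release is None:
        print("nvcc not found on PATH")
    else:
        print(f"nvcc reports CUDA {release}")
        # Assumption taken from this thread: the Windows JAX builds target CUDA 11.
        if int(release.split(".")[0]) != 11:
            print("Note: the nvcc on PATH is not a CUDA 11 toolkit")
```

If several toolkits are installed, the result depends on PATH order, so it is worth running this from the same terminal that launches Python.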

wula2048 commented 9 months ago

I downloaded CUDA 11.8 from the official NVIDIA website and installed it successfully, but I still encounter the same error. [image]

calebweinreb commented 9 months ago

Hi,

Sorry, I should have asked earlier whether you installed via conda or pip. If you install keypoint-moseq using one of the conda env files, conda installs its own copy of CUDA that is separate from the system install. In theory that should have worked. I can re-test that installation method on Monday when I have access to a Windows machine. In the meantime, you are welcome to try the pip installation route, which would involve:

  1. Uninstall CUDA 12 from your system (so CUDA 11 is the only one installed)
  2. Install cuDNN 8.X system-wide by downloading it from NVIDIA
  3. Delete the current keypoint_moseq env
  4. Follow the instructions for the pip install of keypoint_moseq

wula2048 commented 9 months ago

Hello,

Following your instructions, I deleted the original 'keypoint_moseq' environment and reinstalled using pip. Here are the steps I took in the terminal:

conda env remove -n keypoint_moseq
conda create -n keypoint_moseq python=3.9
pip install jax==0.3.22 https://whls.blob.core.windows.net/unstable/cuda111/jaxlib-0.3.22+cuda11.cudnn82-cp39-cp39-win_amd64.whl
pip install keypoint-moseq
conda install -c conda-forge pytables
conda install numpy=1.24
python -m ipykernel install --user --name=keypoint_moseq
python
import keypoint_moseq as kpms

Subsequently, I encountered the following error:

(keypoint_moseq) C:\Users\85033>python
Python 3.9.18 | packaged by conda-forge | (main, Aug 30 2023, 03:40:31) [MSC v.1929 64 bit (AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import keypoint_moseq as kpms
2023-10-15 20:49:22.276586: F external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:454] ptxas returned an error during compilation of ptx to sass: 'INTERNAL: ptxas exited with non-zero error code -1, output: '  If the error message indicates that a file could not be written, please verify that sufficient filesystem space is provided.

I made sure that CUDA 11.1 and cuDNN 8.2 are being used in the environment. nvcc --version output:

(keypoint_moseq) C:\Users\85033>nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Mon_Oct_12_20:54:10_Pacific_Daylight_Time_2020
Cuda compilation tools, release 11.1, V11.1.105
Build cuda_11.1.relgpu_drvr455TC455_06.29190527_0

cuDNN version: [image]

Upon searching for this error on Google, I found a report of a similar issue at 'https://github.com/cloudhan/jax-windows-builder/issues/19'. I realized that this could be an issue with JAX. I followed the instructions at 'https://github.com/cloudhan/jax-windows-builder/issues/19#issuecomment-1514294488' to set up a new environment, but unfortunately, the problem persists, and I'm still getting the same error.

calebweinreb commented 9 months ago

Hmm, I'm not sure what to suggest from here. It seems like a JAX problem, so it might be worth posting an issue elsewhere or following other suggestions from Google. If you do manage to create an env where you can import jax and run e.g. jax.random.PRNGKey(0) without error, then I am happy to help make sure you are able to install the rest of the keypoint-moseq package on top of that.
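The sanity check described above can be sketched as a small script. This is a minimal sketch; the function name `check_jax` and the report keys are hypothetical. In addition to `jax.random.PRNGKey(0)`, it runs a small QR decomposition, since that is the call that reached cuSolver in the traceback. Note that a fatal XLA abort (such as the ptxas failure shown earlier) terminates the interpreter and cannot be caught by the `try` block.

```python
def check_jax() -> dict:
    """Try the operations mentioned in this thread and report what happened."""
    report = {
        "imported": False,
        "backend": None,
        "prng_ok": False,
        "qr_ok": False,
        "error": None,
    }
    try:
        import jax
        import jax.numpy as jnp

        report["imported"] = True
        report["backend"] = jax.default_backend()  # e.g. "cpu" or "gpu"

        key = jax.random.PRNGKey(0)  # the check suggested in the thread
        report["prng_ok"] = True

        # QR decomposition is the operation that failed in the traceback.
        x = jax.random.normal(key, (4, 4))
        r = jnp.linalg.qr(x, mode="r")
        r.block_until_ready()  # force execution so backend errors surface here
        report["qr_ok"] = True
    except Exception as exc:  # ImportError or a backend RuntimeError
        report["error"] = f"{type(exc).__name__}: {exc}"
    return report


if __name__ == "__main__":
    for name, value in check_jax().items():
        print(f"{name}: {value}")
```

If `backend` reports `cpu` on a machine with a GPU, JAX did not pick up the GPU build, which is a different problem from the cuSolver error above.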

wula2048 commented 9 months ago

Thank you for your quick response. I will continue to try installing JAX in a new environment.

wula2048 commented 9 months ago

I'm now using keypoint-moseq via WSL2 (CUDA 11.8, cuDNN 8.8), and it works fine on the demo data, so I'll close this issue.