Hi all,
I am trying to use keypoint-moseq. After installing the virtual environment using the conda method, I tried to validate the pipeline with the demo data you provided, but I get the following error when running `model = kpms.init_model(data, pca=pca, **config())`:
JaxStackTraceBeforeTransformation Traceback (most recent call last)
File [d:\anaconda\envs\keypoint_moseq\lib\runpy.py:197](file:///D:/anaconda/envs/keypoint_moseq/lib/runpy.py:197), in _run_module_as_main(***failed resolving arguments***)
196 sys.argv[0] = mod_spec.origin
--> 197 return _run_code(code, main_globals, None,
198 "__main__", mod_spec)
File [d:\anaconda\envs\keypoint_moseq\lib\runpy.py:87](file:///D:/anaconda/envs/keypoint_moseq/lib/runpy.py:87), in _run_code(***failed resolving arguments***)
80 run_globals.update(__name__ = mod_name,
81 __file__ = fname,
82 __cached__ = cached,
(...)
85 __package__ = pkg_name,
86 __spec__ = mod_spec)
---> 87 exec(code, run_globals)
88 return run_globals
File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\ipykernel_launcher.py:17](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/ipykernel_launcher.py:17)
15 from ipykernel import kernelapp as app
---> 17 app.launch_new_instance()
File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\traitlets\config\application.py:1043](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/traitlets/config/application.py:1043), in Application.launch_instance(***failed resolving arguments***)
1042 app.initialize(argv)
-> 1043 app.start()
File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\ipykernel\kernelapp.py:736](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/ipykernel/kernelapp.py:736), in IPKernelApp.start(***failed resolving arguments***)
735 try:
--> 736 self.io_loop.start()
737 except KeyboardInterrupt:
File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\tornado\platform\asyncio.py:195](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/tornado/platform/asyncio.py:195), in BaseAsyncIOLoop.start(***failed resolving arguments***)
194 def start(self) -> None:
--> 195 self.asyncio_loop.run_forever()
File [d:\anaconda\envs\keypoint_moseq\lib\asyncio\base_events.py:601](file:///D:/anaconda/envs/keypoint_moseq/lib/asyncio/base_events.py:601), in BaseEventLoop.run_forever(***failed resolving arguments***)
600 while True:
--> 601 self._run_once()
602 if self._stopping:
File [d:\anaconda\envs\keypoint_moseq\lib\asyncio\base_events.py:1905](file:///D:/anaconda/envs/keypoint_moseq/lib/asyncio/base_events.py:1905), in BaseEventLoop._run_once(***failed resolving arguments***)
1904 else:
-> 1905 handle._run()
1906 handle = None
File [d:\anaconda\envs\keypoint_moseq\lib\asyncio\events.py:80](file:///D:/anaconda/envs/keypoint_moseq/lib/asyncio/events.py:80), in Handle._run(***failed resolving arguments***)
79 try:
---> 80 self._context.run(self._callback, *self._args)
81 except (SystemExit, KeyboardInterrupt):
File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\ipykernel\kernelbase.py:516](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/ipykernel/kernelbase.py:516), in Kernel.dispatch_queue(***failed resolving arguments***)
515 try:
--> 516 await self.process_one()
517 except Exception:
File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\ipykernel\kernelbase.py:505](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/ipykernel/kernelbase.py:505), in Kernel.process_one(***failed resolving arguments***)
504 return None
--> 505 await dispatch(*args)
File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\ipykernel\kernelbase.py:412](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/ipykernel/kernelbase.py:412), in Kernel.dispatch_shell(***failed resolving arguments***)
411 if inspect.isawaitable(result):
--> 412 await result
413 except Exception:
File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\ipykernel\kernelbase.py:740](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/ipykernel/kernelbase.py:740), in Kernel.execute_request(***failed resolving arguments***)
739 if inspect.isawaitable(reply_content):
--> 740 reply_content = await reply_content
742 # Flush output before sending the reply.
File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\ipykernel\ipkernel.py:422](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/ipykernel/ipkernel.py:422), in IPythonKernel.do_execute(***failed resolving arguments***)
421 if with_cell_id:
--> 422 res = shell.run_cell(
423 code,
424 store_history=store_history,
425 silent=silent,
426 cell_id=cell_id,
427 )
428 else:
File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\ipykernel\zmqshell.py:546](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/ipykernel/zmqshell.py:546), in ZMQInteractiveShell.run_cell(***failed resolving arguments***)
545 self._last_traceback = None
--> 546 return super().run_cell(*args, **kwargs)
File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\IPython\core\interactiveshell.py:3024](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/IPython/core/interactiveshell.py:3024), in InteractiveShell.run_cell(***failed resolving arguments***)
3023 try:
-> 3024 result = self._run_cell(
3025 raw_cell, store_history, silent, shell_futures, cell_id
3026 )
3027 finally:
File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\IPython\core\interactiveshell.py:3079](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/IPython/core/interactiveshell.py:3079), in InteractiveShell._run_cell(***failed resolving arguments***)
3078 try:
-> 3079 result = runner(coro)
3080 except BaseException as e:
File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\IPython\core\async_helpers.py:129](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/IPython/core/async_helpers.py:129), in _pseudo_sync_runner(***failed resolving arguments***)
128 try:
--> 129 coro.send(None)
130 except StopIteration as exc:
File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\IPython\core\interactiveshell.py:3284](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/IPython/core/interactiveshell.py:3284), in InteractiveShell.run_cell_async(***failed resolving arguments***)
3281 interactivity = "none" if silent else self.ast_node_interactivity
-> 3284 has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
3285 interactivity=interactivity, compiler=compiler, result=result)
3287 self.last_execution_succeeded = not has_raised
File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\IPython\core\interactiveshell.py:3466](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/IPython/core/interactiveshell.py:3466), in InteractiveShell.run_ast_nodes(***failed resolving arguments***)
3465 asy = compare(code)
-> 3466 if await self.run_code(code, result, async_=asy):
3467 return True
File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\IPython\core\interactiveshell.py:3526](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/IPython/core/interactiveshell.py:3526), in InteractiveShell.run_code(***failed resolving arguments***)
3525 else:
-> 3526 exec(code_obj, self.user_global_ns, self.user_ns)
3527 finally:
3528 # Reset our crash handler in place
[f:\test\keypoint_moseq\demo_dlc_test.ipynb](file:///F:/test/keypoint_moseq/demo_dlc_test.ipynb) 单元格 8 line 2
[1](vscode-notebook-cell:/f%3A/test/keypoint_moseq/demo_dlc_test.ipynb#X12sZmlsZQ%3D%3D?line=0) # initialize the model
----> [2](vscode-notebook-cell:/f%3A/test/keypoint_moseq/demo_dlc_test.ipynb#X12sZmlsZQ%3D%3D?line=1) model = kpms.init_model(data, pca=pca, **config())
File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\jax_moseq\models\keypoint_slds\initialize.py:328](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/jax_moseq/models/keypoint_slds/initialize.py:328), in init_model(***failed resolving arguments***)
324 pca = utils.fit_pca(
325 Y_flat, pca_mask, PCA_fitting_num_frames, verbose
326 )
--> 328 params = init_params(
329 seed, pca, Y_flat, mask, **hypparams, whiten=whiten, k=Y.shape[-2]
330 )
332 else:
File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\jax_moseq\models\keypoint_slds\initialize.py:110](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/jax_moseq/models/keypoint_slds/initialize.py:110), in init_params(***failed resolving arguments***)
80 """
81 Initialize the parameters of the keypoint SLDS from the
82 data and hyperparameters.
(...)
108 Values for each model parameter.
109 """
--> 110 params = arhmm.init_params(seed, trans_hypparams, ar_hypparams)
111 params["Cd"] = slds.init_obs_params(
112 pca, Y_flat, mask, whiten, **ar_hypparams
113 )
File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\jax_moseq\models\arhmm\initialize.py:101](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/jax_moseq/models/arhmm/initialize.py:101), in init_params(***failed resolving arguments***)
98 params["betas"], params["pi"] = init_hdp_transitions(
99 seed, **trans_hypparams
100 )
--> 101 params["Ab"], params["Q"] = init_ar_params(seed, **ar_hypparams)
102 return params
File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\jax_moseq\models\arhmm\initialize.py:45](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/jax_moseq/models/arhmm/initialize.py:45), in init_ar_params(***failed resolving arguments***)
44 in_axes = (0, na, na, na, na)
---> 45 Ab, Q = jax.vmap(sample_mniw, in_axes)(seeds, nu_0, S_0, M_0, K_0)
46 return Ab, Q
File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\jax_moseq\utils\distributions.py:59](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/jax_moseq/utils/distributions.py:59), in sample_mniw(***failed resolving arguments***)
58 def sample_mniw(seed, nu, S, M, K):
---> 59 sigma = sample_invwishart(seed, S, nu)
60 A = sample_mn(seed, M, sigma, K)
File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\jax_moseq\utils\distributions.py:50](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/jax_moseq/utils/distributions.py:50), in sample_invwishart(***failed resolving arguments***)
47 x = x.at[jnp.triu_indices_from(x, 1)].set(
48 jr.normal(norm_seed, (n * (n - 1) // 2,))
49 )
---> 50 R = jnp.linalg.qr(x, "r")
52 chol = jnp.linalg.cholesky(S)
File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\jax\_src\numpy\linalg.py:570](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/jax/_src/numpy/linalg.py:570), in qr(***failed resolving arguments***)
569 raise ValueError(f"Unsupported QR decomposition mode '{mode}'")
--> 570 q, r = lax_linalg.qr(a, full_matrices=full_matrices)
571 if mode == "r":
JaxStackTraceBeforeTransformation: RuntimeError: jaxlib/cuda/cusolver_kernels.cc:45: operation cusolverDnCreate(&handle) failed: cuSolver internal error
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
RuntimeError Traceback (most recent call last)
[f:\test\keypoint_moseq\demo_dlc_test.ipynb](file:///F:/test/keypoint_moseq/demo_dlc_test.ipynb) 单元格 8 line 2
[1](vscode-notebook-cell:/f%3A/test/keypoint_moseq/demo_dlc_test.ipynb#X12sZmlsZQ%3D%3D?line=0) # initialize the model
----> [2](vscode-notebook-cell:/f%3A/test/keypoint_moseq/demo_dlc_test.ipynb#X12sZmlsZQ%3D%3D?line=1) model = kpms.init_model(data, pca=pca, **config())
File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\jax_moseq\models\keypoint_slds\initialize.py:328](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/jax_moseq/models/keypoint_slds/initialize.py:328), in init_model(data, states, params, hypparams, noise_prior, seed, pca, whiten, PCA_fitting_num_frames, anterior_idxs, posterior_idxs, conf_threshold, error_estimator, trans_hypparams, ar_hypparams, obs_hypparams, cen_hypparams, verbose, exclude_outliers_for_pca, fix_heading, **kwargs)
321 pca_mask = jnp.logical_and(
322 mask, (conf > conf_threshold).all(-1)
323 )
324 pca = utils.fit_pca(
325 Y_flat, pca_mask, PCA_fitting_num_frames, verbose
326 )
--> 328 params = init_params(
329 seed, pca, Y_flat, mask, **hypparams, whiten=whiten, k=Y.shape[-2]
330 )
332 else:
333 params = jax.device_put(params)
File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\jax_moseq\models\keypoint_slds\initialize.py:110](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/jax_moseq/models/keypoint_slds/initialize.py:110), in init_params(seed, pca, Y_flat, mask, trans_hypparams, ar_hypparams, whiten, k, **kwargs)
77 def init_params(
78 seed, pca, Y_flat, mask, trans_hypparams, ar_hypparams, whiten, k, **kwargs
79 ):
80 """
81 Initialize the parameters of the keypoint SLDS from the
82 data and hyperparameters.
(...)
108 Values for each model parameter.
109 """
--> 110 params = arhmm.init_params(seed, trans_hypparams, ar_hypparams)
111 params["Cd"] = slds.init_obs_params(
112 pca, Y_flat, mask, whiten, **ar_hypparams
113 )
114 params["sigmasq"] = jnp.ones(k)
File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\jax_moseq\models\arhmm\initialize.py:101](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/jax_moseq/models/arhmm/initialize.py:101), in init_params(seed, trans_hypparams, ar_hypparams, **kwargs)
97 params = {}
98 params["betas"], params["pi"] = init_hdp_transitions(
99 seed, **trans_hypparams
100 )
--> 101 params["Ab"], params["Q"] = init_ar_params(seed, **ar_hypparams)
102 return params
File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\jax_moseq\models\arhmm\initialize.py:45](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/jax_moseq/models/arhmm/initialize.py:45), in init_ar_params(seed, num_states, nu_0, S_0, M_0, K_0, **kwargs)
43 seeds = jr.split(seed, num_states)
44 in_axes = (0, na, na, na, na)
---> 45 Ab, Q = jax.vmap(sample_mniw, in_axes)(seeds, nu_0, S_0, M_0, K_0)
46 return Ab, Q
[... skipping hidden 3 frame]
File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\jax_moseq\utils\distributions.py:59](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/jax_moseq/utils/distributions.py:59), in sample_mniw(seed, nu, S, M, K)
58 def sample_mniw(seed, nu, S, M, K):
---> 59 sigma = sample_invwishart(seed, S, nu)
60 A = sample_mn(seed, M, sigma, K)
61 return A, sigma
File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\jax_moseq\utils\distributions.py:50](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/jax_moseq/utils/distributions.py:50), in sample_invwishart(seed, S, nu)
46 x = jnp.diag(jnp.sqrt(sample_chi2(chi2_seed, nu - jnp.arange(n))))
47 x = x.at[jnp.triu_indices_from(x, 1)].set(
48 jr.normal(norm_seed, (n * (n - 1) // 2,))
49 )
---> 50 R = jnp.linalg.qr(x, "r")
52 chol = jnp.linalg.cholesky(S)
54 T = jax.scipy.linalg.solve_triangular(R.T, chol.T, lower=True).T
[... skipping hidden 19 frame]
File [d:\anaconda\envs\keypoint_moseq\lib\site-packages\jaxlib\gpu_solver.py:308](file:///D:/anaconda/envs/keypoint_moseq/lib/site-packages/jaxlib/gpu_solver.py:308), in _orgqr_mhlo(platform, gpu_solver, dtype, a, tau)
305 assert tau_dims[:-1] == dims[:-2]
306 k = tau_dims[-1]
--> 308 lwork, opaque = gpu_solver.build_orgqr_descriptor(
309 np.dtype(dtype), batch, m, n, k)
311 layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
312 i32_type = ir.IntegerType.get_signless(32)
RuntimeError: jaxlib/cuda/cusolver_kernels.cc:45: operation cusolverDnCreate(&handle) failed: cuSolver internal error
Additional Information
Operating System
Windows 10
GPU: NVIDIA GeForce RTX 4070
RAM: 16 GB
keypoint-moseq version
keypoint-moseq 0.2.5
cuda and cudnn version
nvcc --version output
nvidia-smi output
I noticed that a similar issue was reported previously (#69), but I have not seen an effective solution there. I look forward to your help in resolving this issue. Thank you!
Hi all, I am trying to use keypoint-moseq. After installing the virtual environment using the conda method, I tried to validate the pipeline with the demo data you provided, but I get the following error when running:
model = kpms.init_model(data, pca=pca, **config())
Additional Information
Operating System
keypoint-moseq version
cuda and cudnn version
nvcc --version output
nvidia-smi output
I noticed that a similar issue was reported previously (#69), but I have not seen an effective solution there. I look forward to your help in resolving this issue. Thank you!