dattalab / keypoint-moseq

https://keypoint-moseq.readthedocs.io

XlaRuntimeError: Internal: CustomCall failed while initializing example dlc model #158

Closed amblypatty closed 3 months ago

amblypatty commented 3 months ago

I have just installed Keypoint-MoSeq on my Windows 10 x64 laptop, which has an Nvidia 940MX GPU and 16 GB of shared GPU memory available. I installed it by following the Windows GPU instructions in the "Installing using conda" section.

When I first started a JupyterLab notebook to run the project setup tutorial, I received this error: "jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: no kernel image is available for execution on the device".

I checked whether JAX could detect my GPU, and it could (it printed 'gpu'). I then followed the rest of the instructions and updated my Nvidia GPU driver to the latest version.
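Printing 'gpu' only shows that JAX found and initialized a GPU backend; it does not show that kernels can actually run on that card, which is what fails later in this thread. A minimal smoke test, sketched under the assumption of a standard JAX install (the function name gpu_smoke_test is hypothetical), lists the visible devices and then forces real computations, including jax.random.split, the call at the bottom of the traceback below:

```python
import jax
import jax.numpy as jnp


def gpu_smoke_test() -> bool:
    """Return True if real computations (not just device discovery) succeed."""
    # Device discovery: this only proves the backend initialized.
    print("default backend:", jax.default_backend())
    for d in jax.devices():
        print(" ", d, "| platform:", d.platform, "| kind:", d.device_kind)
    try:
        # jax.random.split is the operation that fails in the traceback.
        keys = jax.random.split(jax.random.PRNGKey(0))
        keys.block_until_ready()  # force execution; JAX dispatches asynchronously
        # An ordinary XLA-compiled op, for comparison.
        (jnp.ones((8, 8)) @ jnp.ones((8, 8))).block_until_ready()
    except Exception as err:  # XlaRuntimeError's import path differs across JAX versions
        print("computation failed:", err)
        return False
    return True


if __name__ == "__main__":
    print("GPU usable:", gpu_smoke_test())
```

Calling block_until_ready matters: without it, an execution error can surface later, at whatever line first reads the result.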

Afterwards, the error went away and I was able to start creating the demo DLC project. I believe I followed all the steps leading up to initializing the example model, but when I ran the init_model function I got the XlaRuntimeError again, this time with a longer traceback.

Here is the cell I ran:

# initialize the model
model = kpms.init_model(data, pca=pca, **config())

# optionally modify kappa
# model = kpms.update_hypparams(model, kappa=NUMBER)

And here is the traceback:

---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Cell In[18], line 2
      1 # initialize the model
----> 2 model = kpms.init_model(data, pca=pca, **config())
      4 # optionally modify kappa
      5 # model = kpms.update_hypparams(model, kappa=NUMBER)

File ~\keypoint-moseq\keypoint_moseq\fitting.py:104, in init_model(location_aware, allo_hypparams, trans_hypparams, *args, **kwargs)
     97     return allo_keypoint_slds.init_model(
     98         *args,
     99         allo_hypparams=allo_hypparams,
    100         trans_hypparams=trans_hypparams,
    101         **kwargs,
    102     )
    103 else:
--> 104     return keypoint_slds.init_model(
    105         *args, trans_hypparams=trans_hypparams, **kwargs
    106     )

File C:\ProgramData\Anaconda2\envs\keypoint_moseq\lib\site-packages\jax_moseq\models\keypoint_slds\initialize.py:320, in init_model(data, states, params, hypparams, noise_prior, seed, pca, whiten, PCA_fitting_num_frames, anterior_idxs, posterior_idxs, conf_threshold, error_estimator, trans_hypparams, ar_hypparams, obs_hypparams, cen_hypparams, verbose, exclude_outliers_for_pca, fix_heading, **kwargs)
    317             pca_mask = jnp.logical_and(mask, (conf > conf_threshold).all(-1))
    318         pca = utils.fit_pca(Y_flat, pca_mask, PCA_fitting_num_frames, verbose)
--> 320     params = init_params(
    321         seed, pca, Y_flat, mask, **hypparams, whiten=whiten, k=Y.shape[-2]
    322     )
    324 else:
    325     params = jax.device_put(params)

File C:\ProgramData\Anaconda2\envs\keypoint_moseq\lib\site-packages\jax_moseq\models\keypoint_slds\initialize.py:110, in init_params(seed, pca, Y_flat, mask, trans_hypparams, ar_hypparams, whiten, k, **kwargs)
     77 def init_params(
     78     seed, pca, Y_flat, mask, trans_hypparams, ar_hypparams, whiten, k, **kwargs
     79 ):
     80     """
     81     Initialize the parameters of the keypoint SLDS from the
     82     data and hyperparameters.
   (...)
    108         Values for each model parameter.
    109     """
--> 110     params = arhmm.init_params(seed, trans_hypparams, ar_hypparams)
    111     params["Cd"] = slds.init_obs_params(pca, Y_flat, mask, whiten, **ar_hypparams)
    112     params["sigmasq"] = jnp.ones(k)

File C:\ProgramData\Anaconda2\envs\keypoint_moseq\lib\site-packages\jax_moseq\models\arhmm\initialize.py:98, in init_params(seed, trans_hypparams, ar_hypparams, **kwargs)
     77 """
     78 Initialize the parameters of the ARHMM from the
     79 data and hyperparameters.
   (...)
     95     Values for each model parameter.
     96 """
     97 params = {}
---> 98 params["betas"], params["pi"] = init_hdp_transitions(
     99     seed, **trans_hypparams
    100 )
    101 params["Ab"], params["Q"] = init_ar_params(seed, **ar_hypparams)
    102 return params

File C:\ProgramData\Anaconda2\envs\keypoint_moseq\lib\site-packages\jax_moseq\utils\transitions.py:342, in init_hdp_transitions(seed, num_states, alpha, kappa, gamma, **kwargs)
    314 def init_hdp_transitions(seed, num_states, alpha, kappa, gamma, **kwargs):
    315     """
    316     Initialize the transition parameters of the HDP-HMM.
    317 
   (...)
    340         Initial transition probabilities.
    341     """
--> 342     seeds = jr.split(seed)
    343     betas_init = jr.dirichlet(
    344         seeds[0], jnp.full(num_states, gamma / num_states)
    345     )
    346     pseudo_counts = jnp.zeros((num_states, num_states))

File C:\ProgramData\Anaconda2\envs\keypoint_moseq\lib\site-packages\jax\_src\random.py:213, in split(key, num)
    202 """Splits a PRNG key into `num` new keys by adding a leading axis.
    203 
    204 Args:
   (...)
    210   An array-like object of `num` new PRNG keys.
    211 """
    212 key, wrapped = _check_prng_key(key)
--> 213 return _return_prng_keys(wrapped, _split(key, num))

File C:\ProgramData\Anaconda2\envs\keypoint_moseq\lib\site-packages\jax\_src\random.py:199, in _split(key, num)
    196 if key.ndim:
    197   raise TypeError("split accepts a single key, but was given a key array of"
    198                   f"shape {key.shape} != (). Use jax.vmap for batching.")
--> 199 return prng.random_split(key, count=num)

File C:\ProgramData\Anaconda2\envs\keypoint_moseq\lib\site-packages\jax\_src\prng.py:624, in random_split(keys, count)
    623 def random_split(keys, count):
--> 624   return random_split_p.bind(keys, count=count)

File C:\ProgramData\Anaconda2\envs\keypoint_moseq\lib\site-packages\jax\core.py:328, in Primitive.bind(self, *args, **params)
    325 def bind(self, *args, **params):
    326   assert (not config.jax_enable_checks or
    327           all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 328   return self.bind_with_trace(find_top_trace(args), args, params)

File C:\ProgramData\Anaconda2\envs\keypoint_moseq\lib\site-packages\jax\core.py:331, in Primitive.bind_with_trace(self, trace, args, params)
    330 def bind_with_trace(self, trace, args, params):
--> 331   out = trace.process_primitive(self, map(trace.full_raise, args), params)
    332   return map(full_lower, out) if self.multiple_results else full_lower(out)

File C:\ProgramData\Anaconda2\envs\keypoint_moseq\lib\site-packages\jax\core.py:698, in EvalTrace.process_primitive(self, primitive, tracers, params)
    697 def process_primitive(self, primitive, tracers, params):
--> 698   return primitive.impl(*tracers, **params)

File C:\ProgramData\Anaconda2\envs\keypoint_moseq\lib\site-packages\jax\_src\prng.py:636, in random_split_impl(keys, count)
    634 @random_split_p.def_impl
    635 def random_split_impl(keys, *, count):
--> 636   base_arr = random_split_impl_base(
    637       keys.impl, keys.unsafe_raw_array(), keys.ndim, count=count)
    638   return PRNGKeyArray(keys.impl, base_arr)

File C:\ProgramData\Anaconda2\envs\keypoint_moseq\lib\site-packages\jax\_src\prng.py:642, in random_split_impl_base(impl, base_arr, keys_ndim, count)
    640 def random_split_impl_base(impl, base_arr, keys_ndim, *, count):
    641   split = iterated_vmap_unary(keys_ndim, lambda k: impl.split(k, count))
--> 642   return split(base_arr)

File C:\ProgramData\Anaconda2\envs\keypoint_moseq\lib\site-packages\jax\_src\prng.py:641, in random_split_impl_base.<locals>.<lambda>(k)
    640 def random_split_impl_base(impl, base_arr, keys_ndim, *, count):
--> 641   split = iterated_vmap_unary(keys_ndim, lambda k: impl.split(k, count))
    642   return split(base_arr)

File C:\ProgramData\Anaconda2\envs\keypoint_moseq\lib\site-packages\jax\_src\prng.py:1033, in threefry_split(key, num)
   1032 def threefry_split(key: jnp.ndarray, num: int) -> jnp.ndarray:
-> 1033   return _threefry_split(key, int(num))

    [... skipping hidden 6 frame]

File C:\ProgramData\Anaconda2\envs\keypoint_moseq\lib\site-packages\jax\_src\dispatch.py:878, in _execute_compiled(name, compiled, input_handler, output_buffer_counts, result_handler, has_unordered_effects, ordered_effects, kept_var_idx, has_host_callbacks, *args)
    876     runtime_token = None
    877 else:
--> 878   out_flat = compiled.execute(in_flat)
    879 check_special(name, out_flat)
    880 out_bufs = unflatten(out_flat, output_buffer_counts)

XlaRuntimeError: INTERNAL: CustomCall failed: jaxlib/cuda/cuda_prng_kernels.cc:32: operation cudaGetLastError() failed: no kernel image is available for execution on the device

I'm not sure why the error has come back at this point, and the troubleshooting docs don't offer any further steps I can follow. Can you help me understand what I can do to get the example model initialization working?

Cheers, Patrick
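The traceback ends in jr.split -> threefry_split -> a CUDA custom call in jaxlib/cuda/cuda_prng_kernels.cc, so the failing step is the first random-key split inside init_hdp_transitions, before any keypoint-specific computation runs. That suggests the problem can be reproduced without keypoint-moseq at all. A hedged sketch (check_split_per_device is a hypothetical helper, not part of JAX or keypoint-moseq) runs the same split on each visible device and reports which ones succeed:

```python
import numpy as np
import jax


def check_split_per_device() -> dict:
    """Run jax.random.split on every visible device; map 'platform:id' -> succeeded?"""
    # Build the raw threefry key on the host with NumPy so that creating the key
    # does not itself launch a device kernel (PRNGKey(0) is the uint32 pair [0, 0]).
    raw_key = np.zeros(2, dtype=np.uint32)
    results = {}
    for backend in ("cpu", "gpu"):
        try:
            devices = jax.devices(backend)
        except Exception:  # backend not present in this install
            continue
        for device in devices:
            name = f"{device.platform}:{device.id}"
            try:
                # Committing the key to a device makes the split run on that device.
                key = jax.device_put(raw_key, device)
                jax.random.split(key).block_until_ready()
                results[name] = True
            except Exception as err:
                print(name, "failed:", err)
                results[name] = False
    return results


if __name__ == "__main__":
    print(check_split_per_device())
```

If the CPU entry succeeds and the GPU entry fails with the same "no kernel image" message, the problem lies with the GPU and the installed jaxlib rather than with the project data or config.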

calebweinreb commented 3 months ago

What versions of CUDA and cuDNN do you have? And which CUDA driver is installed?
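The driver version and the GPU's compute capability can both be read from nvidia-smi. The compute capability matters here because CUDA reports "no kernel image is available for execution on the device" when a prebuilt kernel contains no code compiled for the card's architecture. A minimal sketch, assuming a driver recent enough to support the compute_cap query field; the helper names are hypothetical, and nvidia-smi does not report the cuDNN version:

```python
import shutil
import subprocess

# Fields requested from nvidia-smi, in output order.
FIELDS = ("name", "driver_version", "compute_cap")


def parse_nvidia_smi_csv(text: str) -> list:
    """Parse '--format=csv,noheader' output into one dict per GPU."""
    gpus = []
    for line in text.strip().splitlines():
        values = [v.strip() for v in line.split(",")]
        if len(values) == len(FIELDS):
            gpus.append(dict(zip(FIELDS, values)))
    return gpus


def query_gpus() -> list:
    """Ask nvidia-smi for each GPU's name, driver version and compute capability."""
    if shutil.which("nvidia-smi") is None:
        return []  # NVIDIA driver tools not on PATH
    try:
        out = subprocess.run(
            ["nvidia-smi", "--query-gpu=" + ",".join(FIELDS), "--format=csv,noheader"],
            capture_output=True, text=True, check=True,
        ).stdout
    except subprocess.CalledProcessError:
        return []  # e.g. an older driver that does not know the compute_cap field
    return parse_nvidia_smi_csv(out)


if __name__ == "__main__":
    for gpu in query_gpus():
        print(gpu)
```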

amblypatty commented 3 months ago

Here is the conda list output for my keypoint_moseq environment:

# packages in environment at C:\ProgramData\Anaconda2\envs\keypoint_moseq:
#
# Name                    Version                   Build  Channel
absl-py                   2.1.0                    pypi_0    pypi
anyio                     4.4.0                    pypi_0    pypi
argon2-cffi               23.1.0                   pypi_0    pypi
argon2-cffi-bindings      21.2.0                   pypi_0    pypi
arrow                     1.3.0                    pypi_0    pypi
asttokens                 2.4.1                    pypi_0    pypi
async-lru                 2.0.4                    pypi_0    pypi
attrs                     24.2.0                   pypi_0    pypi
av                        12.3.0                   pypi_0    pypi
babel                     2.16.0                   pypi_0    pypi
beautifulsoup4            4.12.3                   pypi_0    pypi
blas                      1.0                         mkl
bleach                    6.1.0                    pypi_0    pypi
blosc                     1.21.6               h85f69ea_0    conda-forge
bokeh                     3.4.3                    pypi_0    pypi
bzip2                     1.0.8                h2bbff1b_6
c-blosc2                  2.15.1               hb461149_0    conda-forge
ca-certificates           2024.7.2             haa95532_0
certifi                   2024.7.4                 pypi_0    pypi
cffi                      1.17.0                   pypi_0    pypi
charset-normalizer        3.3.2                    pypi_0    pypi
chex                      0.1.6                    pypi_0    pypi
cloudpickle               3.0.0                    pypi_0    pypi
colorama                  0.4.6                    pypi_0    pypi
colorcet                  3.1.0                    pypi_0    pypi
comm                      0.2.2                    pypi_0    pypi
commentjson               0.9.0                    pypi_0    pypi
contourpy                 1.2.1                    pypi_0    pypi
cuda-nvcc                 12.4.131                      0    nvidia
cuda-version              11.8                 hcce14f8_3
cudatoolkit               11.8.0              h09e9e62_13    conda-forge
cudnn                     8.8.0.121            h84bb9a4_5    conda-forge
cycler                    0.12.1                   pypi_0    pypi
cytoolz                   0.12.3                   pypi_0    pypi
debugpy                   1.8.5                    pypi_0    pypi
decorator                 5.1.1                    pypi_0    pypi
defusedxml                0.7.1                    pypi_0    pypi
dm-tree                   0.1.8                    pypi_0    pypi
dynamax                   0.1.4                    pypi_0    pypi
etils                     1.5.2                    pypi_0    pypi
exceptiongroup            1.2.2                    pypi_0    pypi
executing                 2.0.1                    pypi_0    pypi
fastjsonschema            2.20.0                   pypi_0    pypi
fastprogress              1.0.3                    pypi_0    pypi
fonttools                 4.53.1                   pypi_0    pypi
fqdn                      1.5.1                    pypi_0    pypi
fsspec                    2024.6.1                 pypi_0    pypi
gast                      0.6.0                    pypi_0    pypi
h11                       0.14.0                   pypi_0    pypi
h5py                      3.11.0                   pypi_0    pypi
hdf5                      1.14.3          nompi_h2b43c12_105    conda-forge
hdmf                      3.14.3                   pypi_0    pypi
holoviews                 1.19.1                   pypi_0    pypi
httpcore                  1.0.5                    pypi_0    pypi
httpx                     0.27.0                   pypi_0    pypi
idna                      3.7                      pypi_0    pypi
imageio                   2.35.0                   pypi_0    pypi
imageio-ffmpeg            0.5.1                    pypi_0    pypi
importlib-metadata        8.2.0                    pypi_0    pypi
importlib-resources       6.4.2                    pypi_0    pypi
intel-openmp              2023.1.0         h59b6b97_46320
ipykernel                 6.29.5                   pypi_0    pypi
ipython                   8.18.1                   pypi_0    pypi
ipython-genutils          0.2.0                    pypi_0    pypi
isoduration               20.11.0                  pypi_0    pypi
jax                       0.3.22                   pypi_0    pypi
jax-moseq                 0.2.2                    pypi_0    pypi
jaxlib                    0.3.22                   pypi_0    pypi
jaxtyping                 0.2.14                   pypi_0    pypi
jedi                      0.19.1                   pypi_0    pypi
jinja2                    3.1.4                    pypi_0    pypi
joblib                    1.4.2                    pypi_0    pypi
json5                     0.9.25                   pypi_0    pypi
jsonpointer               3.0.0                    pypi_0    pypi
jsonschema                4.23.0                   pypi_0    pypi
jsonschema-specifications 2023.12.1                pypi_0    pypi
jupyter-client            8.6.2                    pypi_0    pypi
jupyter-core              5.7.2                    pypi_0    pypi
jupyter-events            0.10.0                   pypi_0    pypi
jupyter-lsp               2.2.5                    pypi_0    pypi
jupyter-server            2.14.2                   pypi_0    pypi
jupyter-server-terminals  0.5.3                    pypi_0    pypi
jupyterlab                4.2.4                    pypi_0    pypi
jupyterlab-pygments       0.3.0                    pypi_0    pypi
jupyterlab-server         2.27.3                   pypi_0    pypi
keypoint-moseq            0.4.7                    pypi_0    pypi
kiwisolver                1.4.5                    pypi_0    pypi
krb5                      1.21.3               hdf4eb48_0    conda-forge
lark-parser               0.7.8                    pypi_0    pypi
libaec                    1.1.3                h63175ca_0    conda-forge
libcurl                   8.9.1                h18fefc2_0    conda-forge
libssh2                   1.11.0               h7dfc565_0    conda-forge
libzlib                   1.3.1                h2466b09_1    conda-forge
libzlib-wapi              1.3.1                h2466b09_1    conda-forge
linkify-it-py             2.0.3                    pypi_0    pypi
llvmlite                  0.43.0                   pypi_0    pypi
lz4-c                     1.9.4                h2bbff1b_1
markdown                  3.6                      pypi_0    pypi
markdown-it-py            3.0.0                    pypi_0    pypi
markupsafe                2.1.5                    pypi_0    pypi
matplotlib                3.9.2                    pypi_0    pypi
matplotlib-inline         0.1.7                    pypi_0    pypi
mdit-py-plugins           0.4.1                    pypi_0    pypi
mdurl                     0.1.2                    pypi_0    pypi
mistune                   3.0.2                    pypi_0    pypi
mkl                       2023.1.0         h6b88ed4_46358
mkl-service               2.4.0            py39h2bbff1b_1
mkl_fft                   1.3.8            py39h2bbff1b_0
mkl_random                1.2.4            py39h59b6b97_0
nbclient                  0.10.0                   pypi_0    pypi
nbconvert                 7.16.4                   pypi_0    pypi
nbformat                  5.10.4                   pypi_0    pypi
ndx-pose                  0.1.1                    pypi_0    pypi
nest-asyncio              1.6.0                    pypi_0    pypi
networkx                  3.2.1                    pypi_0    pypi
notebook-shim             0.2.4                    pypi_0    pypi
numba                     0.60.0                   pypi_0    pypi
numexpr                   2.8.7            py39h2cd9be0_0
numpy                     1.26.4           py39h055cbcc_0
numpy-base                1.26.4           py39h65a83cf_0
opencv-python-headless    4.10.0.84                pypi_0    pypi
openssl                   3.3.1                h2466b09_2    conda-forge
opt-einsum                3.3.0                    pypi_0    pypi
optax                     0.1.7                    pypi_0    pypi
optree                    0.12.1                   pypi_0    pypi
overrides                 7.7.0                    pypi_0    pypi
packaging                 24.1             py39haa95532_0
pandas                    2.2.2                    pypi_0    pypi
pandocfilters             1.5.1                    pypi_0    pypi
panel                     1.4.5                    pypi_0    pypi
param                     2.1.1                    pypi_0    pypi
parso                     0.8.4                    pypi_0    pypi
patsy                     0.5.6                    pypi_0    pypi
pillow                    10.4.0                   pypi_0    pypi
pip                       24.2             py39haa95532_0
platformdirs              4.2.2                    pypi_0    pypi
plotly                    5.23.0                   pypi_0    pypi
prometheus-client         0.20.0                   pypi_0    pypi
prompt-toolkit            3.0.47                   pypi_0    pypi
psutil                    6.0.0                    pypi_0    pypi
pure-eval                 0.2.3                    pypi_0    pypi
py-cpuinfo                9.0.0            py39haa95532_0
pycparser                 2.22                     pypi_0    pypi
pygments                  2.18.0                   pypi_0    pypi
pynwb                     2.8.1                    pypi_0    pypi
pyparsing                 3.1.2                    pypi_0    pypi
pytables                  3.9.2            py39h2499b97_3    conda-forge
python                    3.9.19               h1aa4202_1
python-dateutil           2.9.0.post0              pypi_0    pypi
python-json-logger        2.0.7                    pypi_0    pypi
python_abi                3.9                      2_cp39    conda-forge
pytz                      2024.1                   pypi_0    pypi
pyviz-comms               3.0.3                    pypi_0    pypi
pywin32                   306                      pypi_0    pypi
pywinpty                  2.0.13                   pypi_0    pypi
pyyaml                    6.0.2                    pypi_0    pypi
pyzmq                     26.1.0                   pypi_0    pypi
referencing               0.35.1                   pypi_0    pypi
requests                  2.32.3                   pypi_0    pypi
rfc3339-validator         0.1.4                    pypi_0    pypi
rfc3986-validator         0.1.1                    pypi_0    pypi
rpds-py                   0.20.0                   pypi_0    pypi
ruamel-yaml               0.18.6                   pypi_0    pypi
ruamel-yaml-clib          0.2.8                    pypi_0    pypi
scikit-learn              1.5.1                    pypi_0    pypi
scipy                     1.11.3                   pypi_0    pypi
seaborn                   0.13.0                   pypi_0    pypi
send2trash                1.8.3                    pypi_0    pypi
setuptools                72.1.0           py39haa95532_0
simplejson                3.19.3                   pypi_0    pypi
six                       1.16.0                   pypi_0    pypi
sleap-io                  0.1.7                    pypi_0    pypi
snappy                    1.2.1                hcdb6601_0
sniffio                   1.3.1                    pypi_0    pypi
soupsieve                 2.6                      pypi_0    pypi
sqlite                    3.45.3               h2bbff1b_0
stack-data                0.6.3                    pypi_0    pypi
statsmodels               0.14.2                   pypi_0    pypi
tabulate                  0.9.0                    pypi_0    pypi
tbb                       2021.8.0             h59b6b97_0
tenacity                  9.0.0                    pypi_0    pypi
tensorflow-probability    0.19.0                   pypi_0    pypi
terminado                 0.18.1                   pypi_0    pypi
threadpoolctl             3.5.0                    pypi_0    pypi
tinycss2                  1.3.0                    pypi_0    pypi
tomli                     2.0.1                    pypi_0    pypi
toolz                     0.12.1                   pypi_0    pypi
tornado                   6.4.1                    pypi_0    pypi
tqdm                      4.66.5                   pypi_0    pypi
traitlets                 5.14.3                   pypi_0    pypi
typeguard                 4.3.0                    pypi_0    pypi
types-python-dateutil     2.9.0.20240316           pypi_0    pypi
typing-extensions         4.12.2                   pypi_0    pypi
tzdata                    2024.1                   pypi_0    pypi
uc-micro-py               1.0.3                    pypi_0    pypi
ucrt                      10.0.20348.0         haa95532_0
uri-template              1.3.0                    pypi_0    pypi
urllib3                   2.2.2                    pypi_0    pypi
vc                        14.40                h2eaa2aa_0
vc14_runtime              14.40.33810         ha82c5b3_20    conda-forge
vidio                     0.0.4                    pypi_0    pypi
vs2015_runtime            14.40.33810         h3bf8584_20    conda-forge
wcwidth                   0.2.13                   pypi_0    pypi
webcolors                 24.8.0                   pypi_0    pypi
webencodings              0.5.1                    pypi_0    pypi
websocket-client          1.8.0                    pypi_0    pypi
wheel                     0.43.0           py39haa95532_0
xyzservices               2024.6.0                 pypi_0    pypi
zipp                      3.20.0                   pypi_0    pypi
zlib-ng                   2.2.1                he0c23c2_0    conda-forge
zstd                      1.5.6                h0ea2cb4_0    conda-forge

This shows CUDA 11.8 (cudatoolkit 11.8.0) and cuDNN 8.8.0.121. The GPU driver is the Nvidia GeForce Game Ready driver, version 560.81.
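To answer a version question like the one above without pasting the whole environment, the relevant lines can be filtered out of conda list output. A minimal sketch; the helper name watched_versions and the package list are illustrative choices, not part of keypoint-moseq. Run on an excerpt of the listing above, it surfaces at a glance that this environment pairs cudatoolkit 11.8.0 with cuda-nvcc 12.4.131 and jax/jaxlib 0.3.22:

```python
# Packages worth checking when a JAX GPU install misbehaves (illustrative list).
WATCHED = ("cuda-nvcc", "cuda-version", "cudatoolkit", "cudnn", "jax", "jaxlib")


def watched_versions(conda_list_text: str, names=WATCHED) -> dict:
    """Return {package: version} for the watched packages in conda list output."""
    versions = {}
    for line in conda_list_text.splitlines():
        if not line.strip() or line.lstrip().startswith("#"):
            continue  # skip blank lines and header comments
        fields = line.split()  # name, version, build[, channel]
        if len(fields) >= 2 and fields[0] in names:
            versions[fields[0]] = fields[1]
    return versions


# Excerpt from the listing above.
SAMPLE = """\
# Name                    Version                   Build  Channel
cuda-nvcc                 12.4.131                      0    nvidia
cuda-version              11.8                 hcce14f8_3
cudatoolkit               11.8.0              h09e9e62_13    conda-forge
cudnn                     8.8.0.121            h84bb9a4_5    conda-forge
jax                       0.3.22                   pypi_0    pypi
jaxlib                    0.3.22                   pypi_0    pypi
numpy                     1.26.4           py39h055cbcc_0
"""

print(watched_versions(SAMPLE))
# {'cuda-nvcc': '12.4.131', 'cuda-version': '11.8', 'cudatoolkit': '11.8.0',
#  'cudnn': '8.8.0.121', 'jax': '0.3.22', 'jaxlib': '0.3.22'}
```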

calebweinreb commented 3 months ago

I think the issue is that your GPU is too old (see this post: https://github.com/google/jax/issues/5723#issuecomment-1823132970). Support may have been added in the most recent version of JAX, but that version is not available on Windows. Also, your GPU only has 2 GB of VRAM (https://www.techpowerup.com/gpu-specs/geforce-940mx.c2845), so it wouldn't be useful anyway; in general, GPUs built into laptops are not well suited to these kinds of ML tasks. I would recommend using Google Colab, upgrading your hardware, or using the CPU version.
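If the CPU route is taken, one way to try it without reinstalling is to hide the GPU from JAX before JAX is imported. This is a sketch, not the CPU installation described in the keypoint-moseq docs: it assumes a JAX release that reads the JAX_PLATFORMS environment variable (older releases used JAX_PLATFORM_NAME instead), and fitting on the CPU can be considerably slower.

```python
import os

# Must be set before jax is imported anywhere in the process (including
# indirectly by keypoint_moseq). In a notebook, put this in the first cell
# of a freshly restarted kernel.
os.environ["JAX_PLATFORMS"] = "cpu"

import jax  # noqa: E402  (import deliberately placed after the env var)

print("backend:", jax.default_backend())
print("devices:", jax.devices())
```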