dattalab / keypoint-moseq

https://keypoint-moseq.readthedocs.io

jaxlib not installed #40

Closed vickerse1 closed 1 year ago

vickerse1 commented 1 year ago

Hi,

With the Windows GPU version installed, I keep getting the following error:


jax requires jaxlib to be installed

....when I run


import keypoint_moseq as kpms
project_dir = 'demo_project'
config = lambda: kpms.load_config(project_dir)

I uninstalled jaxlib in a previous attempt at the environment - how do I reinstall it, and why didn't the uninstall only affect the previous environment?

Thanks,

Evan

calebweinreb commented 1 year ago

I don't know, but I think the simplest thing to do would be to start over and follow the docs for the Windows install as closely as possible. You mentioned in another issue getting some kind of "kernel image not available" warning and then updating to the nightly build of jax. I think that warning might be benign, so the nightly update is probably unnecessary, and it may be causing these downstream issues. In fact it probably is, since the nightly build doesn't use the Windows-specific builds on https://whls.blob.core.windows.net/, e.g.

pip install jax==0.3.22 https://whls.blob.core.windows.net/unstable/cuda111/jaxlib-0.3.22+cuda11.cudnn82-cp39-cp39-win_amd64.whl
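
Before reinstalling, it may help to confirm which interpreter the notebook kernel is actually using and whether jax and jaxlib are visible from it. The following is a minimal sketch using only the Python standard library; the helper names installed_version and report are hypothetical and are not part of keypoint-moseq.

```python
import sys
from importlib import metadata


def installed_version(package):
    """Return the installed version of `package`, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def report(packages=("jax", "jaxlib", "keypoint-moseq")):
    """Map each package name to its installed version (None if absent)."""
    return {name: installed_version(name) for name in packages}


# The interpreter path shows which environment this kernel really runs in.
print("interpreter:", sys.executable)
for name, version in report().items():
    print(f"{name}: {version or 'NOT INSTALLED'}")
```

If the interpreter path points to a different environment than the one you activated, the notebook kernel is probably registered against the old environment, which would explain why a change made there shows up here.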

GergelyTuri commented 1 year ago

Something similar comes up with the demo Colab notebook as well.

During the pip install steps:

! pip install --upgrade "jax[cuda]==0.3.22" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
! pip install keypoint-moseq

I see this error:

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
chex 0.1.7 requires jax>=0.4.6, but you have jax 0.3.22 which is incompatible.
flax 0.6.9 requires jax>=0.4.2, but you have jax 0.3.22 which is incompatible.
orbax-checkpoint 0.2.1 requires jax>=0.4.8, but you have jax 0.3.22 which is incompatible.

Then, at the import keypoint_moseq as kpms line, I get this traceback:

AttributeError                            Traceback (most recent call last)
<ipython-input-2-5b3c05052691> in <cell line: 1>()
----> 1 import keypoint_moseq as kpms
      2 
      3 project_dir = '/content/drive/Shareddrives/Turi_lab/Data/Context_project/moseq_example/'
      4 config = lambda: kpms.load_config(project_dir)

11 frames
/usr/local/lib/python3.10/dist-packages/jaxtyping/array_types.py in <module>
    665         PRNGKeyArray = Key[jax.Array, "2"]
    666     Scalar = Shaped[jax.Array, ""]
--> 667     ScalarLike = Shaped[jax.typing.ArrayLike, ""]

AttributeError: module 'jax' has no attribute 'typing'

Any idea whether a higher version of jax[cuda] would be helpful here?

calebweinreb commented 1 year ago

Hey sorry about that. This is related to an update of the jaxtyping library. I just uploaded a new version (0.1.2) that pins jaxtyping to version 0.2.14 which should solve the problem.... At least the colab worked for me when I tried it just now.
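
To illustrate why pip printed the conflicts above, the sketch below compares the installed jax version with the minimum versions quoted in the pip output. It is only an illustration: the version numbers are copied from that output, and the parse helper is a hypothetical, simplified stand-in for a real version parser (it handles plain dotted numbers only).

```python
def parse(version):
    """Turn a plain dotted version such as '0.3.22' into a comparable tuple."""
    return tuple(int(part) for part in version.split("."))


installed_jax = "0.3.22"

# Minimum jax versions, as reported in the pip output quoted in this thread.
minimum_jax = {
    "chex 0.1.7": "0.4.6",
    "flax 0.6.9": "0.4.2",
    "orbax-checkpoint 0.2.1": "0.4.8",
}

conflicts = [
    name
    for name, minimum in minimum_jax.items()
    if parse(installed_jax) < parse(minimum)
]

for name in conflicts:
    print(f"{name} needs jax>={minimum_jax[name]}, found {installed_jax}")
```

All three packages end up in the conflicts list. The AttributeError has the same flavor: the jaxtyping release that got installed expects jax.typing, which the traceback shows is missing from jax 0.3.22, so pinning jaxtyping avoids it.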

GergelyTuri commented 1 year ago

Thanks, it worked.

vickerse1 commented 1 year ago

Hi Caleb,

I'm still unsure about my problem.

The kernel still dies in the Windows GPU version on the PCA fitting step, despite the fact that the GPU is being used and resources do not appear to be taxed....

On the installation page it says I should be using CUDA 11.1, but I'm using 11.3. Is this the problem?

What should I do about this: download 11.1 and set the path, or open the env yml and change the version to 11.3, which I already have, and use that? Or just download 11.1 (can I do this without installing it as the main version)?

Or, should I try the WSL install? I know you said you don't support that one, but do you know which version of CUDA should work in Linux?

We were unable to get the Linux version to work with the GPU running CUDA 11.0 (on a network GPU workstation). We are going to install CUDA 12.0 on that system and try again.

Thanks,

Evan



calebweinreb commented 1 year ago

Hi Evan,

Regarding the CUDA version: it has to be 11.1 if you install keypoint-moseq using pip. If you install with conda (i.e. conda env create -f conda_envs\environment.win64_gpu.yml), then conda will install its own copy of CUDA for that env with the right versions of everything.

As for the kernel dying: if it is limited to the PCA step, doesn't occur during any of the subsequent steps, and only happens when jax is installed with GPU support (not with the CPU install), then the easiest thing might be to just run PCA in the CPU env, restart the kernel, and then pick up after PCA in the GPU env...

We can also try to dive deeper into what is going on. The code below is what gets run when you fit PCA. Does it cause the kernel to restart? If so, which line?

from jax_moseq.models.keypoint_slds import preprocess_for_pca
from sklearn.decomposition import PCA
import jax.numpy as jnp
import numpy as np

# data and config come from the earlier notebook cells
Y = data['Y']
mask = data['mask']
conf = data['conf']

cfg = config()
anterior_idxs = cfg['anterior_idxs']
posterior_idxs = cfg['posterior_idxs']
conf_threshold = cfg['conf_threshold']
fix_heading = cfg['fix_heading']
verbose = cfg['verbose']
PCA_fitting_num_frames = cfg['PCA_fitting_num_frames']

# jax_moseq.models.keypoint_slds.fit_pca
Y_flat = preprocess_for_pca(
    Y, anterior_idxs, posterior_idxs, conf,
    conf_threshold, fix_heading, verbose)[0]
# keep only frames that are valid and, if confidences exist, high-confidence
if conf is None: pca_mask = mask
else: pca_mask = jnp.logical_and(mask, (conf > conf_threshold).all(-1))

# jax_moseq.models.utils.fit_pca
Y_flat = Y_flat[pca_mask > 0]
N = Y_flat.shape[0]
N_sample = min(PCA_fitting_num_frames, N)
# fit PCA on a random subset of frames
sample = np.random.choice(N, N_sample, replace=False)
Y_sample = np.array(Y_flat)[sample]
pca = PCA().fit(Y_sample)
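
When the kernel dies without a Python exception, it can be hard to tell which line was running. One option, not mentioned in this thread, is the standard-library faulthandler module, which writes the Python traceback to a file when the process receives a fatal signal such as a segmentation fault. The sketch below assumes the crash is a hard fault of that kind; the log filename is arbitrary.

```python
import faulthandler

# Keep the file object alive for the whole session: faulthandler writes to
# its file descriptor at crash time, so the file must stay open.
crash_log = open("kernel_crash.log", "w")

# Dump the traceback of every thread if the process hits a fatal signal.
faulthandler.enable(file=crash_log, all_threads=True)

print("faulthandler enabled:", faulthandler.is_enabled())
```

Run this in the first notebook cell, then step through the code above one statement per cell. If the kernel restarts, kernel_crash.log should show the Python frame that was executing. If the process is killed from outside (for example, for running out of memory), nothing will be written.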

vickerse1 commented 1 year ago

Hi Caleb,

So it doesn't matter which version is indicated in the env yml file with the conda install?

I made a new Windows GPU environment, and now I'm getting this error with the first cell in the notebook:

import keypoint_moseq as kpms
project_dir = 'demo_project'
config = lambda: kpms.load_config(project_dir)

gives:

[I 2023-05-15 12:52:00.520 ServerApp] Connecting to kernel 943a127f-f9f3-41ae-a347-4c51c5a7854d.
[I 2023-05-15 12:52:11.220 ServerApp] Kernel started: 2a475891-79a3-4785-9a87-27ffc761c3ff
[I 2023-05-15 12:52:11.222 ServerApp] Kernel shutdown: 943a127f-f9f3-41ae-a347-4c51c5a7854d
[I 2023-05-15 12:52:13.252 ServerApp] Connecting to kernel 2a475891-79a3-4785-9a87-27ffc761c3ff.
[I 2023-05-15 12:52:13.278 ServerApp] Connecting to kernel 2a475891-79a3-4785-9a87-27ffc761c3ff.
2023-05-15 12:52:23.338874: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_asm_compiler.cc:61] cuLinkAddData fails. This is usually caused by stale driver version.
2023-05-15 12:52:23.339854: E external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc:1369] The CUDA linking API did not work. Please use XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 to bypass it, but expect to get longer compilation time due to the lack of multi-threading.

How do I implement its XLA_FLAGS suggestion? I entered this literally and it gave an error.

Thanks,

Evan
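
On the XLA_FLAGS question: XLA_FLAGS is an environment variable, not a Python statement, which is why typing it literally into a cell fails. A minimal sketch of one way to set it from inside a notebook follows. It has to run before jax (and therefore keypoint_moseq) is imported, so restart the kernel first. The helper add_xla_flag is hypothetical and only avoids overwriting flags that are already set.

```python
import os

FLAG = "--xla_gpu_force_compilation_parallelism=1"


def add_xla_flag(flag, env=os.environ):
    """Append `flag` to XLA_FLAGS in `env` unless it is already present."""
    flags = env.get("XLA_FLAGS", "").split()
    if flag not in flags:
        flags.append(flag)
    env["XLA_FLAGS"] = " ".join(flags)
    return env["XLA_FLAGS"]


add_xla_flag(FLAG)
print("XLA_FLAGS =", os.environ["XLA_FLAGS"])

# Only import jax-based packages after the variable is set, e.g.:
# import keypoint_moseq as kpms
```

Alternatively, the variable can be set in the terminal before launching Jupyter, for example with set XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 in a Windows command prompt. Note that the log itself says the first error is usually caused by a stale driver version, so updating the GPU driver may be worth trying as well.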



calebweinreb commented 1 year ago

It definitely does matter which CUDA is indicated in the .yml file. But it doesn't matter which CUDA you have installed globally since conda will install its own.
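
To check whether a freshly created environment actually sees the GPU, a quick sketch like the one below can be run in the first notebook cell. It uses jax.default_backend() and jax.devices(); the wrapper function describe_jax_backend is hypothetical and only exists so that a missing or broken install produces a readable message instead of a traceback.

```python
def describe_jax_backend():
    """Return a one-line summary of the jax backend, or the reason it failed."""
    try:
        import jax
    except ImportError as err:
        # Covers a missing jax as well as "jax requires jaxlib to be installed".
        return f"jax is not importable: {err}"
    try:
        return f"backend={jax.default_backend()}, devices={jax.devices()}"
    except Exception as err:  # backend initialization failed
        return f"jax imported, but no backend could be initialized: {err}"


print(describe_jax_backend())
```

A backend of gpu means jax found the CUDA libraries installed in the env; a backend of cpu in the GPU env suggests the jaxlib build or the CUDA libraries in that env do not match.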

vickerse1 commented 1 year ago

Hi Caleb,

OK, I now have the GPU version installed on Linux with CUDA 12, and everything runs fine but I get the following error when I try to load the DeepLabCut data (copied all configs, csv files, videos, and folder structure from working local CPU Windows version):


# load data from DeepLabCut
dlc_results = 'dlc_project/pose_est_csv' # can be a file, a directory, or a list of files
coordinates, confidences = kpms.load_deeplabcut_results(dlc_results)

# format data for modeling
data, labels = kpms.format_data(coordinates, confidences=confidences, **config())

ValueError                                Traceback (most recent call last)
Cell In[7], line 3
      1 # load data from DeepLabCut
      2 dlc_results = 'dlc_project/pose_est_csv' # can be a file, a directory, or a list of files
----> 3 coordinates, confidences = kpms.load_deeplabcut_results(dlc_results)
      5 # format data for modeling
      6 data, labels = kpms.format_data(coordinates, confidences=confidences, **config())

ValueError: too many values to unpack (expected 2)

I tried running it with just one input csv file, and received the same error (currently running with 30 csv files and 30 videos).

I'm not sure why I'm getting a different result than I did with my local run with the same data.

-Evan



calebweinreb commented 1 year ago

As of the latest version, load_deeplabcut_results returns a third value (the bodyparts), so that line has to be updated (see the updated tutorial):

coordinates, confidences, bodyparts = kpms.load_deeplabcut_results(dlc_results)
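
If the same notebook has to run against both an older install (two return values) and the latest one (three), a small compatibility shim can absorb the difference. This is a sketch only: unpack_dlc_results is a hypothetical helper, not part of keypoint-moseq, and it assumes the first two values are always the coordinates and confidences.

```python
def unpack_dlc_results(results):
    """Unpack (coordinates, confidences[, bodyparts]) from either API version."""
    coordinates, confidences, *rest = results
    bodyparts = rest[0] if rest else None
    return coordinates, confidences, bodyparts


# Stand-in tuples, used here instead of a real kpms call:
print(unpack_dlc_results(("coords", "confs")))
print(unpack_dlc_results(("coords", "confs", ["nose", "tail"])))
```

In the notebook this would be used as coordinates, confidences, bodyparts = unpack_dlc_results(kpms.load_deeplabcut_results(dlc_results)). Keeping the Windows and Linux installs on the same keypoint-moseq version is the simpler fix.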