vickerse1 closed this issue.
I don't know, but I think the simplest thing to do would be to start over and follow the docs for the Windows install as closely as possible. You mentioned in another issue getting some kind of "kernel image not available" warning and then updating to the nightly build of jax. I think that warning might be benign, so the nightly update is probably unnecessary, and it may be causing these downstream issues. In fact, it probably is, since the nightly build doesn't use the Windows-specific builds on https://whls.blob.core.windows.net/, e.g.:
pip install jax==0.3.22 https://whls.blob.core.windows.net/unstable/cuda111/jaxlib-0.3.22+cuda11.cudnn82-cp39-cp39-win_amd64.whl
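Before starting over, it may help to confirm which jax and jaxlib builds the environment actually contains. A minimal, standard-library-only sketch is below; the helper names are hypothetical, and treating a ".dev" tag in the version string as the sign of a nightly build is an assumption rather than something stated in the install docs.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("jax", "jaxlib")):
    """Return {package: installed version string, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def looks_like_nightly(version_string):
    """Heuristic (assumption): nightly/dev builds carry a '.dev' tag, e.g. '0.4.9.dev20230510'."""
    return version_string is not None and ".dev" in version_string


if __name__ == "__main__":
    for pkg, ver in installed_versions().items():
        status = "not installed" if ver is None else ver
        note = "  <- looks like a nightly/dev build" if looks_like_nightly(ver) else ""
        print(f"{pkg}: {status}{note}")
```

A Windows wheel such as the one above should report a local version like 0.3.22+cuda11.cudnn82, which this heuristic does not flag.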
Something similar comes up with the demo Colab notebook as well.
During the pip install steps:
! pip install --upgrade "jax[cuda]==0.3.22" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
! pip install keypoint-moseq
I see this error:
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
chex 0.1.7 requires jax>=0.4.6, but you have jax 0.3.22 which is incompatible.
flax 0.6.9 requires jax>=0.4.2, but you have jax 0.3.22 which is incompatible.
orbax-checkpoint 0.2.1 requires jax>=0.4.8, but you have jax 0.3.22 which is incompatible.
Then, at the import keypoint_moseq as kpms line, I get this trace:
AttributeError Traceback (most recent call last)
<ipython-input-2-5b3c05052691> in <cell line: 1>()
----> 1 import keypoint_moseq as kpms
2
3 project_dir = '/content/drive/Shareddrives/Turi_lab/Data/Context_project/moseq_example/'
4 config = lambda: kpms.load_config(project_dir)
11 frames
/usr/local/lib/python3.10/dist-packages/jaxtyping/array_types.py in <module>
665 PRNGKeyArray = Key[jax.Array, "2"]
666 Scalar = Shaped[jax.Array, ""]
--> 667 ScalarLike = Shaped[jax.typing.ArrayLike, ""]
AttributeError: module 'jax' has no attribute 'typing'
Any idea whether a higher version of jax[cuda] would be helpful here?
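The resolver warnings above come from packages that were already installed (chex, flax, orbax-checkpoint), each of which declares a minimum jax version newer than the pinned 0.3.22. A minimal sketch of that comparison is below; the parser is deliberately simplified (plain dotted release numbers only, with local tags such as +cuda11 dropped), so real tooling should use packaging.version instead.

```python
def release_tuple(version_string):
    """Parse a plain dotted release such as '0.4.6' into (0, 4, 6).

    Simplified on purpose: a local tag such as '+cuda11.cudnn82' is dropped,
    and pre-release/dev suffixes are not supported.
    """
    release = version_string.split("+", 1)[0]
    return tuple(int(part) for part in release.split("."))


def meets_minimum(installed, minimum):
    """Return True if `installed` satisfies a '>= minimum' requirement."""
    a, b = release_tuple(installed), release_tuple(minimum)
    width = max(len(a), len(b))
    # Pad with zeros so that '0.4' compares equal to '0.4.0'.
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


# Minimum jax versions copied from the resolver output above.
minimums = {"chex 0.1.7": "0.4.6", "flax 0.6.9": "0.4.2", "orbax-checkpoint 0.2.1": "0.4.8"}
for package, minimum in minimums.items():
    verdict = "OK" if meets_minimum("0.3.22", minimum) else "conflict"
    print(f"{package} requires jax>={minimum}: {verdict}")
```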
Hey, sorry about that. This is related to an update of the jaxtyping library. I just uploaded a new version of keypoint-moseq (0.1.2) that pins jaxtyping to version 0.2.14, which should solve the problem. At least, the Colab notebook worked for me when I tried it just now.
Thanks, it worked.
Hi Caleb,
I'm still unsure about my problem.
The kernel still dies on the PCA fitting step in the Windows GPU version, even though the GPU is being used and resources do not appear to be taxed.
On the installation page it says I should be using CUDA 11.1, but I'm using 11.3. Is this the problem?
What should I do about this? Should I download 11.1 and set the path, or open the env .yml file, change the version to 11.3 (which I already have), and use that? Or should I just download 11.1 (can I do this without installing it as the main version)?
Or should I try the WSL install? I know you said you don't support that one, but do you know which version of CUDA should work in Linux?
We were unable to get the Linux version to work with the GPU running CUDA 11.0 (on a networked GPU workstation), so we are going to install CUDA 12.0 on that system and try again.
Thanks,
Evan
From: Caleb Weinreb
Sent: Friday, May 12, 2023 2:16 PM
To: dattalab/keypoint-moseq
Cc: Evan Vickers; Author
Subject: Re: [dattalab/keypoint-moseq] jaxlib not installed (Issue #40)
Hi Evan,
Regarding the CUDA version: it has to be 11.1 if you install keypoint-moseq using pip. If you install with conda (i.e. conda env create -f conda_envs\environment.win64_gpu.yml), then conda will install its own version of CUDA for that env, with the right versions of everything.
In terms of the kernel dying: if it is limited to the PCA step, doesn't occur during any of the subsequent steps, and only happens when jax is installed with GPU support (not with the CPU-only install), then the easiest thing might be to run PCA in the CPU env, restart the kernel, and then pick up after PCA in the GPU env.
We can also try to dig deeper into what is going on. The code below is what gets run when you fit PCA. Does it cause the kernel to restart? If so, on which line?
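The CPU-then-GPU workaround needs the fitted PCA to survive the kernel restart. A minimal sketch using the standard library's pickle module (fitted scikit-learn estimators can be pickled) is below; the helper names and the file name pca.p are hypothetical, and keypoint-moseq may provide its own save/load helpers for this, so its docs are worth checking first.

```python
import pickle


def save_fitted(obj, path):
    """Pickle a fitted object (e.g. a scikit-learn PCA) to `path`."""
    with open(path, "wb") as f:
        pickle.dump(obj, f)


def load_fitted(path):
    """Load an object previously written by save_fitted()."""
    with open(path, "rb") as f:
        return pickle.load(f)


# Hypothetical workflow (the file name "pca.p" is an arbitrary choice):
# 1. CPU environment: fit PCA as usual, then  save_fitted(pca, "pca.p")
# 2. Restart in the GPU environment, then     pca = load_fitted("pca.p")
```

Pickles are tied to library versions, so both environments should have the same scikit-learn version, and only files you created yourself should be unpickled.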
# Reproduces the steps run when fitting PCA, so you can see which line (if any)
# makes the kernel restart. Assumes `data` and `config` were defined in earlier
# notebook cells (e.g. via kpms.format_data and kpms.load_config).
from jax_moseq.models.keypoint_slds import preprocess_for_pca
from sklearn.decomposition import PCA
import jax.numpy as jnp
import numpy as np

Y = data['Y']
mask = data['mask']
conf = data['conf']

cfg = config()
anterior_idxs = cfg['anterior_idxs']
posterior_idxs = cfg['posterior_idxs']
conf_threshold = cfg['conf_threshold']
fix_heading = cfg.get('fix_heading', False)
verbose = cfg.get('verbose', False)
PCA_fitting_num_frames = cfg['PCA_fitting_num_frames']

# jax_moseq.models.keypoint_slds.fit_pca
Y_flat = preprocess_for_pca(
    Y, anterior_idxs, posterior_idxs, conf,
    conf_threshold, fix_heading, verbose)[0]
if conf is None:
    pca_mask = mask
else:
    # keep only frames where every keypoint clears the confidence threshold
    pca_mask = jnp.logical_and(mask, (conf > conf_threshold).all(-1))

# jax_moseq.models.utils.fit_pca
Y_flat = Y_flat[pca_mask > 0]  # use the confidence-filtered mask computed above
N = Y_flat.shape[0]
N_sample = min(PCA_fitting_num_frames, N)
sample = np.random.choice(N, N_sample, replace=False)
Y_sample = np.array(Y_flat)[sample]
pca = PCA().fit(Y_sample)
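Printed markers can be lost when the process dies, which makes "which line?" hard to answer. A minimal sketch using Python's built-in faulthandler module is below: it dumps a traceback to a file on a hard crash such as a segmentation fault in native code. Whether it catches this particular crash is an assumption, and the log file name and checkpoint helper are hypothetical.

```python
import faulthandler

# Hypothetical log file, kept open so faulthandler can write to it during a crash.
crash_log = open("pca_crash_trace.log", "w")
# On a fatal error (e.g. SIGSEGV, or a Windows exception), dump every thread's traceback here.
faulthandler.enable(file=crash_log, all_threads=True)


def checkpoint(label):
    """Record a progress marker in the log and flush it to disk immediately."""
    crash_log.write(f"[checkpoint] {label}\n")
    crash_log.flush()
    print(f"[checkpoint] {label}", flush=True)
```

Calling checkpoint("before preprocess_for_pca"), checkpoint("before PCA().fit"), and so on between the lines of the PCA code means that, after a crash, the last marker in pca_crash_trace.log shows the last step that completed.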
Hi Caleb,
So it doesn't matter which version is indicated in the env yml file with the conda install?
I made a new Windows GPU environment, and now I'm getting this error with the first cell in the notebook:
import keypoint_moseq as kpms
project_dir = 'demo_project'
config = lambda: kpms.load_config(project_dir)
gives:
[I 2023-05-15 12:52:00.520 ServerApp] Connecting to kernel 943a127f-f9f3-41ae-a347-4c51c5a7854d.
[I 2023-05-15 12:52:11.220 ServerApp] Kernel started: 2a475891-79a3-4785-9a87-27ffc761c3ff
[I 2023-05-15 12:52:11.222 ServerApp] Kernel shutdown: 943a127f-f9f3-41ae-a347-4c51c5a7854d
[I 2023-05-15 12:52:13.252 ServerApp] Connecting to kernel 2a475891-79a3-4785-9a87-27ffc761c3ff.
[I 2023-05-15 12:52:13.278 ServerApp] Connecting to kernel 2a475891-79a3-4785-9a87-27ffc761c3ff.
2023-05-15 12:52:23.338874: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_asm_compiler.cc:61] cuLinkAddData fails. This is usually caused by stale driver version.
2023-05-15 12:52:23.339854: E external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc:1369] The CUDA linking API did not work. Please use XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 to bypass it, but expect to get longer compilation time due to the lack of multi-threading.
How do I apply its XLA_FLAGS suggestion? I entered it literally and it gave an error.
Thanks,
Evan
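The XLA_FLAGS suggestion in the log is an environment variable, not a line of Python. A minimal sketch of one way to set it from the notebook is below; it has to run at the top of a fresh kernel, before jax or keypoint_moseq is imported. The helper name is hypothetical; the flag itself is copied verbatim from the log.

```python
import os

FLAG = "--xla_gpu_force_compilation_parallelism=1"  # copied verbatim from the XLA log message


def add_xla_flag(flag, env=os.environ):
    """Append `flag` to XLA_FLAGS, keeping any flags already set and avoiding duplicates."""
    existing = env.get("XLA_FLAGS", "").split()
    if flag not in existing:
        existing.append(flag)
    env["XLA_FLAGS"] = " ".join(existing)
    return env["XLA_FLAGS"]


# Run this in a fresh kernel *before* `import jax` / `import keypoint_moseq`.
add_xla_flag(FLAG)
```

Alternatively, in a Windows Command Prompt, run set XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 and then launch Jupyter from that same prompt. The log also says cuLinkAddData failures are usually caused by a stale driver, so updating the NVIDIA driver may be the more durable fix.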
From: Caleb Weinreb
Sent: Monday, May 15, 2023 12:32 PM
To: dattalab/keypoint-moseq
Cc: Evan Vickers; Author
Subject: Re: [dattalab/keypoint-moseq] jaxlib not installed (Issue #40)
It definitely does matter which CUDA version is specified in the .yml file. But it doesn't matter which CUDA version you have installed globally, since conda will install its own.
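To see which CUDA version the conda environment file asks for, one option is to list its CUDA-related dependency lines. A minimal, standard-library-only sketch is below; it assumes the relevant entries have "cuda" or "cudnn" in their names, which is an assumption about environment.win64_gpu.yml rather than something confirmed here, and it does a plain line scan rather than real YAML parsing.

```python
def cuda_pins(yml_text):
    """Return dependency entries in a conda env file that mention CUDA or cuDNN.

    Plain line scan, not a YAML parser: it looks for list items ('- ...')
    whose text contains 'cuda' or 'cudnn' (case-insensitive).
    """
    pins = []
    for line in yml_text.splitlines():
        entry = line.strip()
        lowered = entry.lower()
        if entry.startswith("- ") and ("cuda" in lowered or "cudnn" in lowered):
            pins.append(entry[2:].strip())
    return pins


# Usage, with the env file named in the conda command above:
# with open(r"conda_envs\environment.win64_gpu.yml") as f:
#     print(cuda_pins(f.read()))
```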
Hi Caleb,
OK, I now have the GPU version installed on Linux with CUDA 12, and everything runs fine, but I get the following error when I try to load the DeepLabCut data (I copied all configs, CSV files, videos, and the folder structure from the working local CPU Windows version):
dlc_results = 'dlc_project/pose_est_csv' # can be a file, a directory, or a list of files
coordinates, confidences = kpms.load_deeplabcut_results(dlc_results)
ValueError Traceback (most recent call last)
Cell In[7], line 3
1 # load data from DeepLabCut
2 dlc_results = 'dlc_project/pose_est_csv' # can be a file, a directory, or a list of files
----> 3 coordinates, confidences = kpms.load_deeplabcut_results(dlc_results)
5 # format data for modeling
6 data, labels = kpms.format_data(coordinates, confidences=confidences, **config())
I tried running it with just one input csv file, and received the same error (currently running with 30 csv files and 30 videos).
I'm not sure why I'm getting a different result than I did with my local run with the same data.
-Evan
From: Caleb Weinreb
Sent: Monday, May 15, 2023 12:59 PM
To: dattalab/keypoint-moseq
Cc: Evan Vickers; Author
Subject: Re: [dattalab/keypoint-moseq] jaxlib not installed (Issue #40)
That line has to be updated as of the latest version (see the updated tutorial):
coordinates, confidences, bodyparts = kpms.load_deeplabcut_results(dlc_results)
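Presumably the ValueError above is Python's "too many values to unpack" error: the newer load_deeplabcut_results returns three values (coordinates, confidences, bodyparts), but the old line has only two names on the left. Updating the line is the simple fix. If one notebook has to run against both older and newer keypoint-moseq versions, a minimal sketch of a tolerant unpacking helper is below; unpack_dlc_results is hypothetical and not part of keypoint-moseq.

```python
def unpack_dlc_results(results):
    """Split a loader result into (coordinates, confidences, bodyparts).

    Accepts the older 2-value form (bodyparts is then None) or the newer 3-value form.
    """
    if len(results) == 3:
        coordinates, confidences, bodyparts = results
    elif len(results) == 2:
        coordinates, confidences = results
        bodyparts = None
    else:
        raise ValueError(f"expected 2 or 3 values, got {len(results)}")
    return coordinates, confidences, bodyparts


# Hypothetical usage:
# coordinates, confidences, bodyparts = unpack_dlc_results(kpms.load_deeplabcut_results(dlc_results))
```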
Hi,
With the Windows GPU version installed, I keep getting the following error:
jax requires jaxlib to be installed
...when I run:
import keypoint_moseq as kpms
project_dir = 'demo_project'
config = lambda: kpms.load_config(project_dir)
I uninstalled jaxlib during a previous attempt at setting up the environment. How do I reinstall it, and why didn't the uninstall affect only that previous environment?
Thanks,
Evan
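One possible explanation for an uninstall hitting the wrong environment is that the pip on the PATH belongs to a different Python than the one running the notebook kernel. A minimal, standard-library-only sketch is below: it reports which interpreter and environment the kernel uses, whether jax and jaxlib are importable from it, and builds a pip command that targets that same interpreter. The helper names are hypothetical, and the jaxlib spec to install should come from the install docs (e.g. the Windows wheel URL earlier in this thread).

```python
import importlib.util
import sys


def kernel_environment(packages=("jax", "jaxlib")):
    """Report the interpreter/environment this kernel runs in and which packages it can import."""
    return {
        "executable": sys.executable,  # the Python binary running this kernel
        "prefix": sys.prefix,          # root of the active environment (venv or conda env)
        "importable": {name: importlib.util.find_spec(name) is not None for name in packages},
    }


def pip_install_command(spec):
    """Build a pip command that installs `spec` into *this* interpreter's environment."""
    return [sys.executable, "-m", "pip", "install", spec]


if __name__ == "__main__":
    info = kernel_environment()
    print("Python:", info["executable"])
    print("Environment:", info["prefix"])
    for name, ok in info["importable"].items():
        print(f"{name}: {'importable' if ok else 'NOT importable'}")
```

If the report shows jaxlib as not importable, running the command built by pip_install_command (for example with subprocess.check_call) installs into the kernel's own environment rather than whichever pip happens to be first on the PATH.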