dattalab / keypoint-moseq

https://keypoint-moseq.readthedocs.io

Fitting an AR-HMM only runs 1 iteration instead of 50 as specified #26

Closed. juliagorman closed this issue 1 year ago.

juliagorman commented 1 year ago

Despite specifying 50 iterations of model fitting, the model only runs one iteration and then moves on to the next line of code. This happens with every file I have tried.

[Screenshot: 2023-04-25 at 11:30:42 AM]

calebweinreb commented 1 year ago

Fitting terminated because NaNs were detected, as shown in the error message. The troubleshooting docs describe some options for dealing with this.
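For readers unfamiliar with this behavior, here is a toy sketch of an iterative fit that checks for NaNs after every update and stops early. It is not keypoint-moseq's implementation; `toy_fit` and `bad_update` are hypothetical names chosen only to show why a fit configured for 50 iterations can end after the first one.

```python
import math


def toy_fit(update, params, num_iters):
    """Toy iterative fit (hypothetical, not keypoint-moseq's code).

    Applies `update` up to `num_iters` times. If any parameter becomes NaN,
    it stops immediately and reports the iteration where that happened.
    """
    history = [params]
    for it in range(1, num_iters + 1):
        params = update(params)
        if any(math.isnan(p) for p in params):
            return history, it  # terminate early; keep last good parameters
        history.append(params)
    return history, None  # ran all iterations without NaNs


def bad_update(params):
    # 1e308 * 10 overflows to inf, and inf - inf is NaN.
    return [x * 1e308 * 10 - x * 1e308 * 10 for x in params]


history, stopped_at = toy_fit(bad_update, [1.0], num_iters=50)
print(stopped_at)  # 1
```

Stopping early, rather than continuing with NaN parameters, avoids silently producing a meaningless model.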

juliagorman commented 1 year ago

I have checked the troubleshooting docs. I did not get the value True, but I am still running into this issue. This is the only option for dealing with NaNs that I see on that page.

calebweinreb commented 1 year ago

If the command jax.config.read('jax_enable_x64') returned False, then you are using single precision, and that would explain the NaNs. To use double precision, update keypoint-moseq and jax-moseq to their latest versions and restart the notebook kernel:

pip install -U jax-moseq
pip install -U keypoint-moseq
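As a generic illustration of why single precision can lead to NaNs (not the actual computation inside jax-moseq), the NumPy sketch below sums exponentials in float32 and in float64. `exp(100)` is about 2.7e43, which overflows float32 (maximum about 3.4e38) to `inf`, and a subsequent `inf / inf` yields NaN; the same steps stay finite in float64. The helper name `sum_exp` is hypothetical.

```python
import numpy as np


def sum_exp(values, dtype):
    """Sum of exponentials computed in the given dtype (hypothetical helper)."""
    x = np.asarray(values, dtype=dtype)
    return np.sum(np.exp(x))


# Silence NumPy's overflow/invalid warnings so the results print cleanly.
with np.errstate(over="ignore", invalid="ignore"):
    values = [100.0, 100.0]
    d32 = sum_exp(values, np.float32)  # exp(100) overflows float32 -> inf
    d64 = sum_exp(values, np.float64)  # ~5.4e43, representable in float64
    ratio32 = d32 / d32                # inf / inf -> nan
    ratio64 = d64 / d64                # 1.0

print(d32, ratio32)  # inf nan
print(d64, ratio64)
```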

Once you have restarted the notebook and imported keypoint_moseq, you can use the following to check the precision. It should return True.

import jax
jax.config.read('jax_enable_x64')

If it doesn't, you can also manually set JAX to use double precision with:

jax.config.update('jax_enable_x64', True)
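A minimal sketch, assuming a freshly restarted kernel, of enabling double precision and then confirming it by checking the default dtype of a new JAX array instead of reading the config flag. The helper `using_double_precision` is hypothetical. JAX expects this flag to be set at startup, so put it in the first cell, before any JAX arrays are created.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit floats. Run this in the first cell of a freshly restarted
# kernel, before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)


def using_double_precision():
    """Return True if newly created JAX float arrays default to float64."""
    return jnp.zeros(()).dtype.name == "float64"


print(using_double_precision())  # True once x64 mode is active
```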

juliagorman commented 1 year ago

Okay, I imported as instructed above. The command still returned False, so I manually set JAX to double precision as shown in my screenshot. I still got the same error when training the AR-HMM and couldn't run more than one iteration. I ran it both with all three lines uncommented and again with only the line jax.config.read('jax_enable_x64') commented out. Sorry for all these issues!

[Screenshot: 2023-04-26 at 11:18:36 AM]

juliagorman commented 1 year ago

Also, I tried running the lines in this order so that the check returned True, and I still got this error during fitting.

[Screenshots: 2023-04-26 at 11:46:25 AM and 11:46:59 AM]

calebweinreb commented 1 year ago

A few thoughts:

  1. Unless you're using Google Colab, I recommend running the pip update commands outside the notebook, i.e., in a separate terminal window with the keypoint_moseq conda environment activated. Then restart the kernel of the notebook screenshotted above.

  2. After updating the code, make sure you re-run the model fitting notebook from the beginning, i.e., don't use a saved checkpoint.

  3. If you are using Colab, delete and disconnect the runtime, then start over with a new runtime using the install commands as they appear in the example Colab (https://colab.research.google.com/github/dattalab/keypoint-moseq/blob/main/docs/keypoint_moseq_colab.ipynb). Also, your screenshot suggests the GPU is not being used. The code will run much faster if you select GPU as the runtime type, as explained in the linked Colab.
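After updating and restarting, one quick way to confirm what the environment contains is to print the installed package versions and JAX's active backend. This is a sketch using the standard-library `importlib.metadata` and `jax.default_backend()`; the helper name `environment_report` is hypothetical. A backend of `"cpu"` on Colab would match the observation that the GPU is not being used.

```python
from importlib import metadata


def environment_report(packages=("jax", "jax-moseq", "keypoint-moseq")):
    """Collect installed package versions and the active JAX backend.

    Versions come from package metadata (what is installed on disk); the
    kernel only picks up a new version after it has been restarted.
    """
    report = {}
    for name in packages:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = None  # not installed in this environment
    try:
        import jax
        report["backend"] = jax.default_backend()  # e.g. "cpu" or "gpu"
    except ImportError:
        report["backend"] = None
    return report


print(environment_report())
```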

juliagorman commented 1 year ago

Hi Caleb,

I am using Google Colab. Attached are my Colab notebook (https://colab.research.google.com/drive/1Fi5wkmb9nGyKt619rP0UZcwz0x5zhSco?usp=sharing) and the file (https://drive.google.com/file/d/1wz5WVQ9eVb6SiHAXt7wC9yjuT_a25kQu/view?usp=sharing) I am using for analysis. I am still running into the same issues. I set up this notebook exactly like the example notebook you linked. I was unable to access the GPU and I am not sure why. I also still get the exact same error at iteration 1 of model fitting.

JC Gorman, PhD student | Miller lab (https://millerlab.ucsd.edu/) | Neurosciences Graduate Program | UCSD


From: Caleb Weinreb @.> Sent: Thursday, April 27, 2023 12:50 AM To: dattalab/keypoint-moseq @.> Cc: Julia C Gorman @.>; Author @.> Subject: Re: [dattalab/keypoint-moseq] Fitting an AR-HMM only runs 1 iteration instead of 50 as specified (Issue #26)

A couple thoughts:

  1. Unless you're using google colab, I would recommend running the pip update commands outside the notebook, i.e. in another terminal window that has the keypoint_moseq conda environment activated. After that, you will need to restart the kernel of the notebook screenshotted above.

  2. After updating the code, make sure you re-run the model fitting notebook from the beginning. i.e. don't use a saved checkpoint.

  3. If you are using colab, then you should delete and disconnect the runtime, and just start over with a new runtime using the install commands as they appear in the example colabhttps://urldefense.com/v3/__https://colab.research.google.com/github/dattalab/keypoint-moseq/blob/main/docs/keypoint_moseq_colab.ipynb__;!!Mih3wA!G2MZhKAFNhv46QZ40GZbkmsXG5eNVx-aNEL375vsJqdb2lFoYgBv6i65T7eUjzMm_uAnGijnUeXUdfkTS6yTa27ehQ$. Also it seems from your screenshot like the GPU is not being used. The code will run much faster if you select GPU for the runtime type, as explained in the linked colab.

— Reply to this email directly, view it on GitHubhttps://urldefense.com/v3/__https://github.com/dattalab/keypoint-moseq/issues/26*issuecomment-1525017961__;Iw!!Mih3wA!G2MZhKAFNhv46QZ40GZbkmsXG5eNVx-aNEL375vsJqdb2lFoYgBv6i65T7eUjzMm_uAnGijnUeXUdfkTS6xYWjiqEw$, or unsubscribehttps://urldefense.com/v3/__https://github.com/notifications/unsubscribe-auth/AHXPYCJUWISHP5HGVPSUQM3XDIQLZANCNFSM6AAAAAAXLMG7AM__;!!Mih3wA!G2MZhKAFNhv46QZ40GZbkmsXG5eNVx-aNEL375vsJqdb2lFoYgBv6i65T7eUjzMm_uAnGijnUeXUdfkTS6yHQkBJoQ$. You are receiving this because you authored the thread.Message ID: @.***>

juliagorman commented 1 year ago

However, I do not run into this issue when running a Jupyter notebook, so I have switched to Jupyter for the time being.

juliagorman commented 1 year ago

Oops, I spoke too soon. I now get the same issue further down in the code, at the step that applies the trained model:

[Screenshot: 2023-04-27 at 2:51:17 PM]

calebweinreb commented 1 year ago

I'll look into the Jupyter notebook. For the issue above, see https://github.com/dattalab/keypoint-moseq/issues/15#issuecomment-1503626291

TL;DR: you can avoid this by setting num_iters=0:

results = kpms.apply_model(coordinates=coordinates, confidences=confidences, 
                           project_dir=project_dir, **config(), **checkpoint, num_iters=0)