dattalab / keypoint-moseq

https://keypoint-moseq.readthedocs.io

Question on performance on the test dataset #124

Open vigji opened 8 months ago

vigji commented 8 months ago

Hello! First of all, thank you very much for the very useful resource and for the effort in making it accessible to the community!

I have installed keypoint-moseq in a dedicated conda environment created from the environment.win64_gpu.yml file. My machine runs Windows 10, just updated to the latest version, with a GeForce 1070 Ti graphics card (8 GB VRAM) and the latest available drivers (546.33).

From conda list:

...
cuda-nvcc                 12.3.107                      0    nvidia
cuda-version              11.8                 h70ddcb2_2    conda-forge
cudatoolkit               11.8.0              h09e9e62_12    conda-forge
cudnn                     8.8.0.121            h84bb9a4_4    conda-forge
...
jax                       0.3.22                   pypi_0    pypi
jax-moseq                 0.2.1                    pypi_0    pypi
jaxlib                    0.3.22                   pypi_0    pypi
...
keypoint-moseq            0.4.2                    pypi_0    pypi

If I run `python -c "import jax; print(jax.default_backend())"`, I get gpu as expected.
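For completeness, here is a slightly expanded version of that check (the expected outputs in the comments are my assumptions for this setup):

```python
import jax

# Confirm that JAX selected the GPU backend and can see the device.
print(jax.default_backend())  # 'gpu'
print(jax.devices())          # should list the GeForce 1070 Ti
```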

I tried out the tutorial workflow. GPU memory usage is <0.3 GB before starting, jumps to 7.3 GB after JAX initialization, and remains stably there. Initialization and the AR-HMM model fit with 50 iterations run smoothly in ~13 min. But when I start fitting the full model, it crashes silently.
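If it helps with diagnosis: my understanding is that the jump to ~7.3 GB is JAX preallocating about 90% of GPU memory at startup, which can be controlled with standard XLA environment variables (these are generic JAX flags, not keypoint-moseq options):

```python
import os

# Standard JAX/XLA memory flags; they must be set before jax is imported.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate on demand
# Alternatively, cap the preallocated fraction (JAX's default is ~0.9):
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"

import jax
```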

Suspecting an OOM error, I set parallel_message_passing=False, and now it runs. (Is it correct that the test dataset is 643,911 frames? If so, shouldn't 8 GB be enough, since 643,911 frames × 0.01 MB/frame ≈ 6.5 GB, per the FAQ?)
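For concreteness, this is roughly how I am passing the flag, following the tutorial workflow (a sketch rather than my exact code; `model`, `data`, `metadata`, and `project_dir` come from the earlier tutorial steps, and I'm assuming fit_model forwards the keyword to the resampling step):

```python
import keypoint_moseq as kpms

# Full-model fit with parallel message passing disabled to reduce
# GPU memory usage, as suggested in the FAQ.
model, model_name = kpms.fit_model(
    model, data, metadata, project_dir,
    ar_only=False,
    num_iters=500,
    parallel_message_passing=False,
)
```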

With that setting, fitting the full model on the test dataset takes approximately 8 hours; is that what you would expect? It seems plausible at the upper end of the 2-5x slowdown relative to the ~1.5 h estimate for Google Colab (5 × 1.5 h = 7.5 h). I am just a bit surprised that this is significantly slower than the CPU version (approx. 4.5 hours). I guess this is fine, since my actual data is of comparable size (~500k frames); I'm reporting it here mainly to make sure I'm not missing anything and in case it is useful to others.

Thank you very much in advance for any clarifications, and for all the nice and useful work!