bainro closed this 5 months ago
I'll try installing a different version of cuda-nvcc.
conda-forge dropped the nvcc 11.* packages, but they're still available through the nvidia channel:
Installing cuda-nvcc==11.8.89 through the nvidia channel upgraded jax to 0.3.25, but it seems to be working now :) I need to debug a little more to make sure before closing this issue.
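After a reinstall like this, one way to confirm which nvcc release the activated env actually resolves is to parse the output of `nvcc --version`. Below is a minimal Python sketch. It assumes that `nvcc` is on `PATH` in the activated env and that the env targets CUDA 11.x, as in this thread. `parse_nvcc_release` and `nvcc_release` are hypothetical helper names, not part of keypoint-moseq:

```python
import re
import shutil
import subprocess
from typing import Optional, Tuple


def parse_nvcc_release(version_text: str) -> Optional[Tuple[int, int]]:
    """Extract (major, minor) from `nvcc --version` output, e.g. 'release 11.8'."""
    match = re.search(r"release (\d+)\.(\d+)", version_text)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def nvcc_release() -> Optional[Tuple[int, int]]:
    """Run the first nvcc found on PATH; return its release, or None if absent."""
    nvcc = shutil.which("nvcc")
    if nvcc is None:
        return None
    result = subprocess.run(
        [nvcc, "--version"], capture_output=True, text=True, check=False
    )
    return parse_nvcc_release(result.stdout)


if __name__ == "__main__":
    release = nvcc_release()
    if release is None:
        print("nvcc not found on PATH (or its version could not be parsed)")
    elif release[0] != 11:
        # Assumption: the env in this thread expects a CUDA 11.x toolkit.
        print(f"nvcc {release[0]}.{release[1]} found; expected an 11.x release")
    else:
        print(f"nvcc {release[0]}.{release[1]} found")
```

Run it inside the activated env, since `shutil.which` only sees whatever `PATH` the current shell provides.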
Training is working :+1:
Thanks for reporting this and all the detailed notes! Just updated the env file to incorporate your suggestions https://github.com/dattalab/keypoint-moseq/commit/a3c569381dc1ce8040dce1afd6ac7ddc945ce187
I've had trouble installing on my Linux machine for a while. I settled for the CPU-only env, but it was slow and would always fail at some point during training anyway. I'm trying to speed things up with the GPU again, but I'm stuck on the following error from jax:
You can see my CUDA toolkit version and drivers in that image, though I'm not sure whether that's the toolkit my Anaconda env actually picks up. I made no modifications to the conda YAML file; the only thing I changed was etils, to version 1.5.2.
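A rough way to tell whether the toolkit shown on the system is the one the conda env uses is to compare the `nvcc` on `PATH` against `CONDA_PREFIX`, which `conda activate` sets. The sketch below does that, prints the installed jax, jaxlib and etils versions, and reports which backend jax selects. It is only a minimal diagnostic sketch with hypothetical helper names (`is_inside`, `installed_version`, `jax_backend`, `report`). It inspects `PATH`, package metadata and jax's chosen backend, so it will not catch every driver/toolkit mismatch:

```python
import importlib.metadata
import os
import shutil
from pathlib import Path
from typing import Optional


def is_inside(path: str, prefix: str) -> bool:
    """True if `path` resolves to a location under the directory `prefix`."""
    resolved = Path(path).resolve()
    root = Path(prefix).resolve()
    return resolved == root or root in resolved.parents


def installed_version(package: str) -> Optional[str]:
    """Installed distribution version, or None if the package is not installed."""
    try:
        return importlib.metadata.version(package)
    except importlib.metadata.PackageNotFoundError:
        return None


def jax_backend() -> str:
    """Backend jax selects ('gpu' or 'cpu'), or a note on why it couldn't be checked."""
    try:
        import jax
        return jax.default_backend()
    except ImportError:
        return "(jax not installed)"
    except Exception as exc:  # a broken jax/jaxlib install can fail in other ways
        return f"(jax failed to load: {exc})"


def report() -> None:
    prefix = os.environ.get("CONDA_PREFIX")  # set by `conda activate`
    nvcc = shutil.which("nvcc")
    print(f"active conda env: {prefix or '(none)'}")
    print(f"nvcc on PATH:     {nvcc or '(not found)'}")
    if prefix and nvcc:
        where = "inside" if is_inside(nvcc, prefix) else "OUTSIDE"
        print(f"nvcc is {where} the active env")
    for pkg in ("jax", "jaxlib", "etils"):
        print(f"{pkg}: {installed_version(pkg) or '(not installed)'}")
    print(f"jax backend: {jax_backend()}")


if __name__ == "__main__":
    report()
```

If `nvcc` shows up as OUTSIDE the active env, the toolkit in the screenshot is probably the system one rather than the one installed from the env file.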