JAX requires cuDNN 8.9 and, at least based on the package list available on Ubuntu, this also implies CUDA 12.1. I didn't find an explicit reason for the 8.9 requirement, but bumping to it seems to address possible issues with the current precompiled archive, so I think it's worth bumping the requirement.
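For reviewers who want to sanity-check a machine against the new minimums, here is a minimal sketch of the comparison. The helper names `parse_version` and `meets_minimum` are hypothetical and not part of this change, and the version strings are assumed to be plain dotted integers that you have already read from your own install.

```python
def parse_version(text):
    """Turn a dotted version string such as '8.9.2' into a tuple of ints."""
    return tuple(int(part) for part in text.strip().split("."))


def meets_minimum(found, minimum):
    """Return True if `found` is at least `minimum` (both dotted strings)."""
    f, m = parse_version(found), parse_version(minimum)
    # Pad the shorter tuple with zeros so '8.9' and '8.9.0' compare equal.
    width = max(len(f), len(m))
    f += (0,) * (width - len(f))
    m += (0,) * (width - len(m))
    return f >= m


# Minimums taken from the description above.
MIN_CUDNN = "8.9"
MIN_CUDA = "12.1"
```

For example, `meets_minimum("8.6.0", MIN_CUDNN)` is false while `meets_minimum("8.9.2", MIN_CUDNN)` is true.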