In the GPU CI, we're currently using CUDA from the `nvidia/cuda:12.2.2-cudnn8-devel-ubuntu22.04` Docker image. JAX tends to stay pretty bleeding-edge in terms of CUDA requirements, though, so our choices are to manually update the image each time the installation breaks, or to install the CUDA binaries from pip. The latter feels more robust to me, so we'll do that for now. The tradeoff is that the CI takes an extra ~1 minute.
Note that this also means we're compiling jax-finufft against CUDA 12.2 and running it against the 12.3 libraries. CUDA guarantees compatibility across minor versions within the same major release (12.x), so this shouldn't be a problem.
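To make the version reasoning concrete, here is a minimal sketch of the check it implies. It assumes a conservative reading of CUDA's minor-version compatibility: code built against `MAJOR.MINOR` is expected to work with runtime libraries of the same major version and an equal or newer minor version. The helper names `parse_cuda_version` and `runtime_is_compatible` are hypothetical and are not part of jax-finufft, JAX, or CUDA.

```python
def parse_cuda_version(version: str) -> tuple[int, int]:
    """Parse a CUDA version string like '12.2' or '12.3.101' into (major, minor)."""
    parts = version.strip().split(".")
    if len(parts) < 2:
        raise ValueError(f"expected at least MAJOR.MINOR, got {version!r}")
    return int(parts[0]), int(parts[1])


def runtime_is_compatible(build_version: str, runtime_version: str) -> bool:
    """Return True if runtime libraries should work with code built against build_version.

    Hypothetical helper. Assumes a conservative minor-version compatibility rule:
    same major version, and the runtime minor version is at least the build minor.
    """
    build_major, build_minor = parse_cuda_version(build_version)
    run_major, run_minor = parse_cuda_version(runtime_version)
    return build_major == run_major and run_minor >= build_minor


# The CI situation: compiled against 12.2, running against the 12.3 libraries.
print(runtime_is_compatible("12.2.2", "12.3.101"))  # True
```

Crossing a major version (e.g. building against 11.8 and running against 12.x libraries) fails this check, which matches the point that the guarantee only holds within a single major release.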