flatironinstitute / jax-finufft

JAX bindings to the Flatiron Institute Non-uniform Fast Fourier Transform (FINUFFT) library
Apache License 2.0

ci: use cuda from pip #63

Closed: lgarrison closed this pull request 7 months ago

lgarrison commented 7 months ago

In the GPU CI, we're currently using CUDA from the nvidia/cuda:12.2.2-cudnn8-devel-ubuntu22.04 Docker image. JAX tends to require very recent CUDA versions, though, so our options are to manually update the image every time the installation breaks, or to install the CUDA binaries from pip instead. The latter feels more robust to me, so we'll do that for now; the tradeoff is that the CI takes an extra ~1 minute.
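Because the CUDA libraries would now come from pip wheels rather than from the Docker image, a CI job could log exactly which ones were installed. The helper below is a minimal sketch and is hypothetical (`cuda_wheels` is not part of jax-finufft). It assumes NVIDIA's CUDA 12 wheels follow the `nvidia-<component>-cu12` naming used on PyPI (for example `nvidia-cuda-runtime-cu12`).

```python
from importlib.metadata import distributions


def cuda_wheels(suffix: str = "-cu12") -> dict:
    """Return {distribution name: version} for installed NVIDIA CUDA wheels.

    Hypothetical CI helper: assumes the pip packages are named
    "nvidia-<component>-cu12", e.g. "nvidia-cuda-runtime-cu12".
    """
    found = {}
    for dist in distributions():
        # .get() returns None instead of raising if the Name field is missing.
        name = (dist.metadata.get("Name") or "").lower()
        if name.startswith("nvidia-") and name.endswith(suffix):
            found[name] = dist.version
    return found


if __name__ == "__main__":
    # Print in pip-freeze style so the CI log records the exact versions.
    for name, version in sorted(cuda_wheels().items()):
        print(f"{name}=={version}")
```

On a machine without these wheels the function simply returns an empty dict, so the step is safe to run unconditionally.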

Also, note that this means we're compiling jax-finufft against CUDA 12.2 and running it against the 12.3 libraries, but CUDA guarantees compatibility across minor releases within a major version (12.x), so this shouldn't be a problem.
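The 12.2-versus-12.3 reasoning can be expressed as a simple rule: the major versions must match, and the runtime libraries must be at least as new (by minor version) as the toolkit used at build time. This is a simplified sketch of CUDA's minor-version compatibility, not an official check, and the function names `parse_version` and `minor_version_compatible` are hypothetical.

```python
def parse_version(v: str) -> tuple:
    """Parse 'MAJOR.MINOR[.PATCH...]' into an (major, minor) tuple of ints."""
    parts = v.split(".")
    return int(parts[0]), int(parts[1])


def minor_version_compatible(built_with: str, running_on: str) -> bool:
    """Simplified rule: same CUDA major version, and the runtime minor
    version is greater than or equal to the build-time minor version."""
    b_major, b_minor = parse_version(built_with)
    r_major, r_minor = parse_version(running_on)
    return b_major == r_major and r_minor >= b_minor


# Built against the 12.2 toolkit, running against 12.3 libraries from pip.
print(minor_version_compatible("12.2", "12.3"))      # True
print(minor_version_compatible("12.2", "12.3.101"))  # True
print(minor_version_compatible("11.8", "12.3"))      # False
```

Real compatibility also depends on driver version and on features used, so a check like this is only a first-order guard.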

lgarrison commented 7 months ago

Should fix the CI error in #62.