flatironinstitute / jax-finufft

JAX bindings to the Flatiron Institute Non-uniform Fast Fourier Transform (FINUFFT) library
Apache License 2.0

CUDA wheels #88

Open lgarrison opened 4 months ago

lgarrison commented 4 months ago

Just starting to write down my thoughts on how we could build and distribute GPU wheels.

For background on GPU wheels, this is the best summary of the current state of affairs I've found: https://pypackaging-native.github.io/key-issues/gpus/

It seems there are two broadly different approaches we could take:

  1. bundle CUDA in the wheel, or
  2. use "external" CUDA.

(1) is the traditional method and results in large wheels. (2) is what JAX does, but is pretty cutting-edge and under-documented.

The way (1) would work is that we would statically link the CUDA libraries (which is what we're currently doing, I think), or dynamically link and let auditwheel copy the libraries into the wheel. There are a few parts of this I still don't understand, such as how it would work with JAX linking against one CUDA runtime while jax-finufft potentially ships another. Would that result in two CUDA contexts? Clearly it's already working somehow!
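Not an answer, but one way to poke at the two-runtimes question: the sketch below (my own diagnostic, not anything in jax-finufft) lists every copy of libcudart mapped into the current process on Linux after importing jax and jax_finufft. Two entries would mean two dynamically linked runtimes; a statically linked runtime won't show up here at all.

```python
# Rough Linux-only diagnostic: list every libcudart mapped into this process.
# Assumes jax and jax_finufft are importable with their GPU backends.
def mapped_cudart_libs():
    libs = set()
    with open("/proc/self/maps") as maps:
        for line in maps:
            fields = line.split()
            # Mapped files have six fields; the pathname is the last one.
            if len(fields) == 6 and "libcudart" in fields[-1]:
                libs.add(fields[-1])
    return sorted(libs)


if __name__ == "__main__":
    import jax
    import jax_finufft  # noqa: F401

    jax.devices()  # force JAX to initialize its backend
    print(mapped_cudart_libs())
```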

With (2), we would use the NVIDIA CUDA wheels on PyPI. I can't find any official documentation on them, but the Python CUDA tech lead did write this nice tutorial in the cuQuantum repo: https://github.com/NVIDIA/cuQuantum/tree/main/extra/demo_build_with_wheels

It's somewhat hacky, but the basic ideas are clear. We would set the rpath so the extension finds the pip-installed CUDA libraries, using some helper scripts plus auditwheel --exclude to keep specific shared libraries out of the wheel. At runtime, the dynamic linker would pick up the pip-installed CUDA, falling back to a user/system installation if that fails.
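For concreteness, here's a minimal Python-level sketch of that search order. The real mechanism would be an rpath baked into the extension at build time; the package and library names below are assumptions based on the nvidia-cuda-runtime-cu12 wheel layout (site-packages/nvidia/cuda_runtime/lib), not anything jax-finufft does today.

```python
# Sketch of the lookup order: prefer a pip-installed libcudart, then fall
# back to whatever the dynamic linker finds in user/system locations.
# Package/library names are assumptions based on nvidia-cuda-runtime-cu12.
import ctypes
import importlib.util
import pathlib


def load_cudart():
    try:
        spec = importlib.util.find_spec("nvidia.cuda_runtime")
    except ModuleNotFoundError:
        spec = None
    if spec is not None and spec.submodule_search_locations:
        libdir = pathlib.Path(list(spec.submodule_search_locations)[0]) / "lib"
        for so in sorted(libdir.glob("libcudart.so.*")):
            try:
                return ctypes.CDLL(str(so), mode=ctypes.RTLD_GLOBAL)
            except OSError:
                continue
    # No pip-installed runtime found: let the linker search the usual paths.
    return ctypes.CDLL("libcudart.so.12", mode=ctypes.RTLD_GLOBAL)
```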

Either way, I think the build itself can be done with cibuildwheel, probably just with a yum install of the CUDA development libraries (like this project does: https://github.com/OpenNMT/CTranslate2/blob/master/python/tools/prepare_build_environment_linux.sh).
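As a rough sketch of that prep step (written in Python only to match the other snippets here; the repo URL follows NVIDIA's rhel7 layout for the manylinux2014 image, and the package names/version are guesses that would need checking), it might look something like:

```python
# Hypothetical before-all step for cibuildwheel on a manylinux2014 (CentOS 7)
# image, mirroring the CTranslate2 shell script linked above. The repo URL,
# package names, and version suffix are assumptions.
import subprocess

CUDA_REPO = (
    "https://developer.download.nvidia.com/compute/cuda/repos/rhel7/x86_64/"
    "cuda-rhel7.repo"
)


def install_cuda_devel(version: str = "12-3") -> None:
    subprocess.run(["yum", "install", "-y", "yum-utils"], check=True)
    subprocess.run(["yum-config-manager", "--add-repo", CUDA_REPO], check=True)
    subprocess.run(
        [
            "yum", "install", "-y",
            f"cuda-cudart-devel-{version}",
            f"cuda-nvcc-{version}",
            f"libcufft-devel-{version}",
        ],
        check=True,
    )
```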

In terms of PyPI distribution: with (1), thanks to CUDA minor version compatibility, I think we can just do what cupy does and publish jax-finufft-cuda12x and jax-finufft-cuda11x (if we want to support CUDA 11); no need for a custom package index URL. With (2), we would use [cuda_local] and [cuda_pip] extras, as sketched below. I don't think we need a full matrix of cuDNN versions like JAX does, but I could be wrong.
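A hypothetical spelling of those extras, written as a setuptools call for concreteness; the dependency lists are guesses, loosely modeled on how JAX splits its pip vs. local CUDA extras:

```python
# Hypothetical extras for approach (2); the dependency lists are assumptions.
from setuptools import setup

setup(
    name="jax-finufft",
    extras_require={
        # Pull the CUDA runtime pieces from PyPI alongside the package.
        "cuda_pip": [
            "nvidia-cuda-runtime-cu12",
            "nvidia-cufft-cu12",
        ],
        # Rely on a local user/system CUDA installation instead.
        "cuda_local": [],
    },
)
```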

This is all a bit experimental, since this isn't a "vanilla" CUDA extension but one that has to work with JAX! For that reason, (2) seems more appealing: it's more likely that we end up using the same CUDA installation that JAX does.