flatironinstitute / jax-finufft

JAX bindings to the Flatiron Institute Non-uniform Fast Fourier Transform (FINUFFT) library
Apache License 2.0

Fix indexing of nupts for batched nufft (`n_tot > 1`) #47

Closed: lgarrison closed this 11 months ago

lgarrison commented 11 months ago

jax-finufft has two levels of batching: an inner level, where finufft performs multiple transforms with the same set of NU points (`n_transf > 1`), and an outer level, where jax-finufft performs multiple transforms with different NU points (looping over `n_tot`). The NU point arrays therefore have shape `[n_tot, n_j]`, while the source array has shape `[n_tot, n_transf, n_j]`. However, the NU point arrays were being indexed as if they had the source array's three-dimensional shape. This led to out-of-bounds memory accesses on the GPU when using `n_tot > 1`, for example as a result of `jax.vmap`.
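To make the shape mismatch concrete, here is a small NumPy sketch of the flat (row-major) offsets into the two buffers. It assumes C-contiguous layouts and uses hypothetical helper names (`points_offset`, `points_offset_as_if_3d`, `source_offset`); it models only the indexing arithmetic, not the actual jax-finufft CUDA code.

```python
import numpy as np

# Illustrative sizes (hypothetical, chosen only to show the stride mismatch).
n_tot, n_transf, n_j = 3, 2, 5

# NU points: one set per outer batch, shape [n_tot, n_j], C-contiguous.
x = np.arange(n_tot * n_j, dtype=np.float64).reshape(n_tot, n_j)
# Sources: n_transf transforms per outer batch, shape [n_tot, n_transf, n_j].
c = np.arange(n_tot * n_transf * n_j, dtype=np.complex128).reshape(n_tot, n_transf, n_j)
x_flat, c_flat = x.ravel(), c.ravel()


def points_offset(i):
    # Correct: outer batch i's points start i * n_j elements into the buffer.
    return i * n_j


def points_offset_as_if_3d(i):
    # Buggy pattern: indexing the points with the source array's stride,
    # i.e. as if they were shaped [n_tot, n_transf, n_j].
    return i * n_transf * n_j


def source_offset(i, t):
    # Sources really are 3D: transform t of outer batch i.
    return (i * n_transf + t) * n_j


# Check that the correct offsets recover each row of the original arrays.
for i in range(n_tot):
    start = points_offset(i)
    assert np.array_equal(x_flat[start:start + n_j], x[i])
    for t in range(n_transf):
        s = source_offset(i, t)
        assert np.array_equal(c_flat[s:s + n_j], c[i, t])

# With the source stride, the last outer batch reads past the end of x.
buggy_end = points_offset_as_if_3d(n_tot - 1) + n_j
print(f"reads up to element {buggy_end}, buffer holds {x_flat.size}")
# prints: reads up to element 25, buffer holds 15
```

In this simplified model the two strides coincide when `n_transf == 1`, so the overrun shows up once there are multiple inner transforms per outer batch; the real kernel's indexing may differ in detail.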

This fixes the GPU runtime error in #37.

The submodule update points us at a more recent finufft. We're still using a fork while we wait for upstream work to finish, but this update brings us much closer to the current state of upstream. The fork also uses fewer threads per block in certain register-intensive 3D operations, which should fix CUDA errors about insufficient resources.
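Why fewer threads per block helps can be seen from a simple register budget: all threads in a block draw from a fixed pool of registers, so a register-heavy kernel may fail to launch if the block is too large. The sketch below assumes a limit of 65,536 registers per block (common on recent NVIDIA GPUs, but device-dependent) and ignores register allocation granularity. The function names and numbers are hypothetical; the PR does not state the fork's actual block sizes or register counts.

```python
# Simplified register-budget model for a CUDA kernel launch.
# Assumption: 65,536 32-bit registers per block; real limits vary by device,
# and registers are allocated in chunks, which this model ignores.
REGS_PER_BLOCK = 65536
WARP_SIZE = 32


def launch_fits(threads_per_block: int, regs_per_thread: int) -> bool:
    # The whole block's register demand must fit in the per-block budget.
    return threads_per_block * regs_per_thread <= REGS_PER_BLOCK


def max_threads_per_block(regs_per_thread: int) -> int:
    # Largest whole number of warps whose register demand fits the budget.
    threads = REGS_PER_BLOCK // regs_per_thread
    return (threads // WARP_SIZE) * WARP_SIZE


# A register-intensive kernel using 128 registers per thread:
print(launch_fits(1024, 128))       # False: 131072 registers requested
print(launch_fits(512, 128))        # True
print(max_threads_per_block(128))   # 512
```

Under this model, halving the block size halves the block's register demand, which is the kind of change that turns a failed launch into a successful one.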

CC @Matematija

Matematija commented 11 months ago

Amazing. I will report back as soon as I test things out on my side.

Matematija commented 11 months ago

Works for me now. Thank you again.

lgarrison commented 11 months ago

Interestingly, it looks like the same bug was fixed on the CPU a few years ago (d4622b861dac6dbf96239cff5f2650b3d74d5921), but that was after the GPU fork, and we never ported the fix to the GPU code. Oops!