lgarrison closed this 11 months ago
Amazing. I will report back as soon as I test things out on my side.
Works for me now. Thank you again.
Interestingly, it looks like the same bug was fixed on the CPU a few years ago (d4622b861dac6dbf96239cff5f2650b3d74d5921), but that was after the GPU fork, and we never ported the fix to the GPU code. Oops!
jax-finufft has two levels of batching: an inner level where finufft does multiple transforms for the same set of NU points (`n_transf > 1`), and an outer level where jax-finufft does multiple transforms with different NU points (looping over `n_tot`). Therefore, the NU points arrays have shape `[n_tot, n_j]`, and the source array has shape `[n_tot, n_transf, n_j]`. However, the NU points arrays were being indexed as if they had the latter shape, which was leading to out-of-bounds memory accesses on the GPU when trying to use `n_tot > 1`, e.g. as a result of `jax.vmap`.

This fixes the GPU runtime error in #37.
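For reference, here's a minimal sketch (not part of this PR) of how the two batching levels look from the caller's side. It assumes the `nufft1` interface from the jax-finufft README, and all sizes are arbitrary:

```python
# A minimal sketch of the two batching levels, assuming the `nufft1`
# interface described in the jax-finufft README.
import jax
import jax.numpy as jnp

from jax_finufft import nufft1

n_tot, n_transf, n_j, n_k = 4, 2, 1000, 64  # made-up sizes

key_x, key_c = jax.random.split(jax.random.PRNGKey(0))

# Outer batch: n_tot independent sets of NU points, shape [n_tot, n_j].
points = 2 * jnp.pi * jax.random.uniform(key_x, (n_tot, n_j)) - jnp.pi

# Inner batch: n_transf transforms per point set, shape [n_tot, n_transf, n_j].
source = jax.random.normal(key_c, (n_tot, n_transf, n_j), dtype=jnp.complex64)

# The inner n_transf batch is handled by finufft itself; the outer n_tot
# batch comes from jax.vmap. Before this fix, n_tot > 1 read the points
# array out of bounds on the GPU.
f = jax.vmap(lambda c, x: nufft1(n_k, c, x))(source, points)
print(f.shape)  # (n_tot, n_transf, n_k)
```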
The submodule update points us at a more recent finufft. We're still using a fork while we wait for upstream work to finish, but this update brings us much closer to the current state of upstream. The fork also uses fewer threads per block in certain register-intensive 3D operations, which should fix CUDA errors complaining about insufficient resources.
CC @Matematija