astro-informatics / s2fft

Differentiable and accelerated spherical transforms with JAX
https://astro-informatics.github.io/s2fft
MIT License

Check autodiff and batching support for `healpix_fft_cuda` primitive and add if needed #237

Open matt-graham opened 1 month ago

matt-graham commented 1 month ago

I think the primitive added in #204 may not support automatic differentiation and batching transforms, as we did not define Jacobian-vector product and transpose rules (for autodiff) or a batching rule (for `vmap` support). We should verify whether this is the case and add implementations if necessary.
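A quick way to run that check is to probe each transform and record which ones raise. This is a minimal sketch: `check_transforms` is a hypothetical helper (not part of s2fft), and `jnp.fft.fft` is used only as a stand-in for the function wrapping the `healpix_fft_cuda` primitive.

```python
import jax
import jax.numpy as jnp


def check_transforms(fn, x):
    """Hypothetical helper: report which JAX transforms work on fn at input x."""
    probes = {
        # Forward-mode autodiff needs a JVP rule.
        "jvp": lambda: jax.jvp(fn, (x,), (jnp.ones_like(x),)),
        # Reverse-mode autodiff needs JVP and transpose rules (or a custom VJP).
        "vjp": lambda: jax.vjp(fn, x)[1](jnp.ones_like(fn(x))),
        # vmap needs a batching rule.
        "vmap": lambda: jax.vmap(fn)(jnp.stack([x, x])),
    }
    results = {}
    for name, probe in probes.items():
        try:
            jax.block_until_ready(probe())
            results[name] = True
        except Exception:
            # A missing rule surfaces as an exception when the transform is applied.
            results[name] = False
    return results


# Stand-in input and function; swap in the wrapped primitive to test it.
x = jnp.arange(8.0).astype(jnp.complex64)
report = check_transforms(jnp.fft.fft, x)
print(report)
```

Any entry reported as `False` would indicate a rule that still needs to be registered for the primitive.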

ASKabalan commented 1 month ago

Hello @matt-graham, I think the gradient is straightforward since the FFT is linear. The adjoint for the forward pass has to be a spectral folding followed by FFTs in the reverse order, correct? For vmap this might be a bit challenging, since XLA provides only one CUDA stream and, from experience, forking a stream has some overhead. Let me think about it. Some of the JAX guys suggest I use Pallas instead of CUDA, which would solve the latter issue. I don't know what you guys think about it.
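One possible way to add `vmap` support without touching streams is to register a batching rule that hands the whole batch to a single batched call. The sketch below uses `jax.custom_batching.custom_vmap` and assumes the underlying kernel can accept a leading batch axis, which has not been confirmed for `healpix_fft_cuda`. The names `ring_fft` and `ring_fft_vmap` are hypothetical, and `jnp.fft.fft` stands in for the CUDA kernel.

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap


@custom_vmap
def ring_fft(x):
    # Stand-in for the unbatched kernel call (assumption: not the real primitive).
    return jnp.fft.fft(x)


@ring_fft.def_vmap
def ring_fft_vmap(axis_size, in_batched, xs):
    # xs arrives with the mapped axis in position 0, so one batched call
    # over the last axis replaces axis_size separate kernel launches.
    del axis_size, in_batched
    out = jnp.fft.fft(xs, axis=-1)
    out_batched = True  # the output carries the batch axis in position 0
    return out, out_batched


batch = jnp.arange(24.0).reshape(3, 8).astype(jnp.complex64)
batched_out = jax.vmap(ring_fft)(batch)
reference = jnp.fft.fft(batch, axis=-1)
```

If the kernel cannot take a batch axis directly, the rule body could instead loop over the batch, at the cost of sequential launches.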

jasonmcewen commented 1 month ago

Yes, as you say @ASKabalan, since the operations are linear, gradients can be computed via inverse transforms. Precisely how to do this is outlined in our paper here in Section 5. That could be a good approach here. This is also how we implemented the VJPs for the C wrappers that we added to s2fft, although that trades off accuracy between the forward passes and VJPs when adding in HEALPix iterations (which we should do soon).
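To illustrate the idea on a toy case, the sketch below attaches a hand-written VJP to a linear transform with `jax.custom_vjp` and checks it against JAX's built-in rule. It is not the s2fft implementation: `forward` is a hypothetical stand-in built on `jnp.fft.fft`, chosen because the DFT matrix is symmetric, so its (unconjugated) transpose is again a forward FFT. For the actual HEALPix transform the backward pass would be the corresponding transform described in the paper.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def forward(x):
    # Stand-in for an opaque linear kernel with no autodiff rules of its own.
    return jnp.fft.fft(x)


def forward_fwd(x):
    # A linear map needs no residuals for its backward pass.
    return forward(x), None


def forward_bwd(_, ct):
    # The DFT matrix is symmetric, so transposing it gives another forward FFT.
    return (jnp.fft.fft(ct),)


forward.defvjp(forward_fwd, forward_bwd)

x = (jnp.arange(8.0) + 1j * jnp.arange(8.0)[::-1]).astype(jnp.complex64)
ct = (jnp.ones(8) - 2j * jnp.arange(8.0)).astype(jnp.complex64)

# Compare the hand-written VJP with the one JAX derives for jnp.fft.fft.
got = jax.vjp(forward, x)[1](ct)[0]
ref = jax.vjp(jnp.fft.fft, x)[1](ct)[0]
```

Comparing against a reference VJP like this is also a cheap regression test once a custom rule is registered for the real primitive.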