matt-graham opened this issue 1 month ago
Hello @matt-graham, I think the gradient is straightforward since the FFT is linear. The adjoint of the forward pass (for autograd) has to be a spectral folding followed by FFTs in the reverse order, correct? vmap might be a bit more challenging, since XLA provides only one CUDA stream, and in my experience forking a stream has some overhead. Let me think about it. Some of the JAX developers suggested I use Pallas instead of CUDA, which would solve the latter issue. I don't know what you think about it.
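The "reverse order" point can be checked numerically in JAX. In the minimal sketch below, the two stages are hypothetical stand-ins: `stage_a` is an FFT along one axis and `stage_b` is a random real mixing matrix along the other. The spectral-folding step is NOT implemented here. JAX's reverse mode uses the (non-conjugating) transpose of each linear step, and the transpose of a composition applies the stages' transposes in reverse order. Because the DFT matrix is symmetric, the transpose of an FFT is again an FFT.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical two-stage linear transform (stand-in, not the s2fft pipeline):
# an FFT along axis 1 ("phi"), then a random mixing matrix along axis 0.
rng = np.random.default_rng(0)
n_theta, n_phi, n_out = 4, 8, 3
mix = jnp.asarray(rng.standard_normal((n_out, n_theta)), dtype=jnp.complex64)


def stage_a(x):
    # Unnormalised FFT along the "phi" axis.
    return jnp.fft.fft(x, axis=1)


def stage_b(y):
    # Linear mixing along the "theta" axis.
    return mix @ y


def forward(x):
    return stage_b(stage_a(x))


def stage_b_t(ct):
    # Transpose of y -> mix @ y is ct -> mix.T @ ct.
    return mix.T @ ct


def stage_a_t(ct):
    # The DFT matrix is symmetric, so its transpose is another forward FFT.
    return jnp.fft.fft(ct, axis=1)


def forward_t(ct):
    # Transpose of a composition: the stages' transposes in reverse order.
    return stage_a_t(stage_b_t(ct))


# Compare the hand-written transpose with the one JAX derives automatically.
x = jnp.zeros((n_theta, n_phi), dtype=jnp.complex64)
ct = jnp.asarray(
    rng.standard_normal((n_out, n_phi)) + 1j * rng.standard_normal((n_out, n_phi)),
    dtype=jnp.complex64,
)
(auto_t,) = jax.linear_transpose(forward, x)(ct)
print(jnp.allclose(auto_t, forward_t(ct), atol=1e-4))
```

With a real folding step in place of `stage_b`, the same comparison against `jax.linear_transpose` would be one way to test a hand-written adjoint.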
Yes, as you say @ASKabalan, since the operations are linear, gradients can be computed via inverse transforms. Precisely how to do this is outlined in Section 5 of our paper here, and that could be a good approach for this case too. This is also how we implemented the VJPs for the C wrappers that we added to s2fft, although that trades off accuracy between the forward passes and the VJPs when adding in HEALPix iterations (which we should do soon).
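As a minimal sketch of computing a VJP with an inverse transform, assume an orthonormal 1D FFT as a stand-in for the s2fft transforms. `fft_ortho` and its rules below are hypothetical and not part of s2fft. For the unitary DFT matrix F, F⁻¹ = Fᴴ = conj(F) (F is symmetric), so the transpose that JAX needs can be written as Fᵀ ct = conj(F⁻¹ conj(ct)). For an unnormalised FFT, the same identity holds with an extra factor of n.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.custom_vjp
def fft_ortho(x):
    """Unitary (norm="ortho") FFT, a stand-in for a forward transform."""
    return jnp.fft.fft(x, norm="ortho")


def fft_ortho_fwd(x):
    # The map is linear, so no residuals are needed for the backward pass.
    return fft_ortho(x), None


def fft_ortho_bwd(_, ct):
    # Apply the transpose F^T = F using the inverse transform:
    # F^T ct = conj(F^{-1} conj(ct)).
    return (jnp.conj(jnp.fft.ifft(jnp.conj(ct), norm="ortho")),)


fft_ortho.defvjp(fft_ortho_fwd, fft_ortho_bwd)

# Compare against the VJP JAX derives for the built-in FFT.
rng = np.random.default_rng(1)
x = jnp.asarray(rng.standard_normal(16) + 1j * rng.standard_normal(16), dtype=jnp.complex64)
ct = jnp.asarray(rng.standard_normal(16) + 1j * rng.standard_normal(16), dtype=jnp.complex64)
_, vjp_custom = jax.vjp(fft_ortho, x)
_, vjp_ref = jax.vjp(lambda z: jnp.fft.fft(z, norm="ortho"), x)
print(jnp.allclose(vjp_custom(ct)[0], vjp_ref(ct)[0], atol=1e-5))
```

This only holds because the forward map here is exactly invertible. A forward pass that is only an approximate inverse would need a separate check of the resulting VJP.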
I think the primitive added in #204 may not support automatic differentiation or batched transforms, as we did not define Jacobian-vector product (JVP) and transpose rules (for autodiff) or a batching rule (for vmap support). We should verify whether this is the case and add implementations if necessary.
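One way to run that verification is to apply each transformation to the transform and see whether JAX reports a missing rule. In JAX, a missing JVP, transpose, or batching rule on a primitive surfaces as `NotImplementedError`. The helper `check_transform_support` below is hypothetical, and `jnp.fft.fft` is only a stand-in. The primitive-backed transform from #204 would be passed in its place.

```python
import jax
import jax.numpy as jnp


def check_transform_support(f, x):
    """Report which JAX transformations work on f at the example input x.

    Missing JVP, transpose, or batching rules on a primitive used by f are
    raised by JAX as NotImplementedError when the transformation is applied.
    """
    checks = {
        # Forward mode: needs a JVP rule.
        "jvp": lambda: jax.jvp(f, (x,), (jnp.ones_like(x),)),
        # Reverse mode: needs a JVP rule plus a transpose rule.
        "vjp": lambda: jax.vjp(f, x)[1](f(x)),
        # Batching: needs a batching rule.
        "vmap": lambda: jax.vmap(f)(jnp.stack([x, x])),
    }
    results = {}
    for name, run in checks.items():
        try:
            run()
            results[name] = True
        except NotImplementedError:
            results[name] = False
    return results


# The built-in FFT defines all three rules, so every check should pass.
x = jnp.arange(8, dtype=jnp.complex64)
print(check_transform_support(jnp.fft.fft, x))
```

Any entry that comes back `False` points to the rule that still needs to be registered for the primitive.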