astro-informatics / s2fft

Differentiable and accelerated spherical transforms with JAX
https://astro-informatics.github.io/s2fft
MIT License

Explore removing loop in HEALPix FFTs (to reduce JIT compile time for high L) #140

Open: jasonmcewen opened this issue 1 year ago

EiffL commented 1 year ago

Yep, I think that can be improved a bit.

EiffL commented 1 year ago

oookkkk I admit defeat... I can't figure out how to do it. I was hoping some clever use of https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.dynamic_slice.html would work, but no, the slice size has to be statically known... CodeGPT/ChatGPT/Copilot were no help at all :-( (there might be some use cases for good old humans after all ^^')
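The constraint described above can be reproduced in a few lines. In `jax.lax.dynamic_slice`, the start indices may be traced values, but `slice_sizes` must be concrete under `jit` because the output shape has to be known at trace time. The function names below are illustrative only and are not part of s2fft.

```python
import numpy as np
import jax
import jax.numpy as jnp

x = jnp.arange(10.0)

@jax.jit
def slice_traced_start(x, start):
    # `start` is traced under jit; the slice size (3) is a static Python int.
    return jax.lax.dynamic_slice(x, (start,), (3,))

def slice_traced_size(x, start, size):
    # Under jit, `size` becomes a tracer, but slice_sizes must be concrete
    # because the output shape has to be known at trace time.
    return jax.lax.dynamic_slice(x, (start,), (size,))

print(slice_traced_start(x, 4))  # [4. 5. 6.]

try:
    jax.jit(slice_traced_size)(x, 4, 3)
    size_rejected = False
except TypeError:
    # JAX reports a non-concrete slice size with a TypeError (or a subclass).
    size_rejected = True
print("traced slice size rejected:", size_rejected)
```

Marking `size` as static (`jax.jit(slice_traced_size, static_argnums=2)`) makes the call legal again. However, each distinct size then needs its own trace and compilation, which is essentially the per-ring cost this issue is about.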

Of course, the nukular option remains... One could implement healpix_fft_jax directly in CUDA and wrap it in JAX. It's only a matter of running a bunch of FFTs, so it's pretty trivial. Here is an example of how it's done: https://github.com/dfm/extending-jax

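But that brings two questions:

  • Would you be happy for s2fft to also have some custom compiled ops? The potential problem is that it would require a compilation step, as opposed to a pure JAX library. Plus, long-term maintenance has a higher cost because you have to keep up with the undocumented JAX custom op API.
  • As long as we are implementing things in CUDA, would it be smarter to directly wrap a full SHT for HEALPix, as opposed to just the ring-wise FFTs?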

For now though, I guess living with the for loop version is OK. I think compilation is quite a bit faster now that I'm running on JAX v0.4.
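The following simplified sketch shows why the loop hurts compile time. It assumes a real-valued HEALPix map in RING ordering and ignores the phase shifts and padding that s2fft's actual `healpix_fft_jax` performs. The helpers `healpix_ring_lengths` and `make_ring_fft_loop` are hypothetical.

In RING ordering, the rings in each polar cap hold 4*i pixels (i = 1..nside-1), and the 2*nside + 1 equatorial rings hold 4*nside pixels each. Because the ring lengths differ, a Python loop over rings is unrolled at trace time into one FFT op per ring.

```python
import numpy as np
import jax
import jax.numpy as jnp

def healpix_ring_lengths(nside):
    # Pixels per iso-latitude ring (RING ordering): 4*i in each polar cap,
    # 4*nside in each of the 2*nside + 1 equatorial rings; 12*nside**2 in total.
    cap = [4 * i for i in range(1, nside)]
    return np.array(cap + [4 * nside] * (2 * nside + 1) + cap[::-1])

def make_ring_fft_loop(nside):
    lengths = healpix_ring_lengths(nside)
    offsets = np.concatenate([[0], np.cumsum(lengths)[:-1]])
    rings = [(int(o), int(n)) for o, n in zip(offsets, lengths)]

    @jax.jit
    def ring_fft(x):
        # The Python loop runs at trace time, so the compiled program contains
        # one FFT op per ring (4*nside - 1 in total), many with different sizes.
        return [jnp.fft.fft(x[o:o + n]) for o, n in rings]

    return ring_fft
```

The number of traced FFT ops grows linearly with nside. Higher band-limits L need larger nside, so the traced program, and with it the compile time, grows too.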

jasonmcewen commented 1 year ago

The nuclear option is a possibility, but only as a very last resort, I would say. As you say, it is not going to be the easiest to maintain. But if it is the only way, we could do it. I think implementing the minimal amount of code in CUDA would be best, so basically just healpix_fft_jax etc. That would help with maintainability and shouldn't have any further performance impact. Still, it would be worth exploring some other alternatives first.

Beyond a more efficient JAX implementation of the current HEALPix algorithm, we could perhaps consider variants of the algorithm that avoid this issue, or at least mitigate it. For example, perhaps we could compute a fixed higher-resolution ring FFT and downsample the rings. I will need to look into this in more detail and give it some more thought.
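One exact mitigation in this spirit is sketched below. It is an illustration swapped in here, not something proposed in the thread or implemented in s2fft. In RING ordering, all 2*nside + 1 equatorial rings have the same length (4*nside), so they can be reshaped into a 2-D array and transformed with a single batched FFT. Only the 2*(nside - 1) polar-cap rings still need separate, unrolled FFTs. The traced FFT count therefore drops from 4*nside - 1 to 2*nside - 1. The helper names are hypothetical, and the phase shifts of the real transform are again ignored.

```python
import numpy as np
import jax
import jax.numpy as jnp

def healpix_ring_lengths(nside):
    # Pixels per ring in RING ordering (see the HEALPix pixelisation).
    cap = [4 * i for i in range(1, nside)]
    return np.array(cap + [4 * nside] * (2 * nside + 1) + cap[::-1])

def make_ring_fft_batched(nside):
    lengths = healpix_ring_lengths(nside)
    offsets = np.concatenate([[0], np.cumsum(lengths)[:-1]])
    n_cap = nside - 1                        # rings per polar cap
    n_eq = 2 * nside + 1                     # equatorial rings, all 4*nside long
    eq_start = int(offsets[n_cap])           # first pixel of the equatorial belt
    eq_stop = eq_start + n_eq * 4 * nside
    # (offset, length) of the polar-cap rings: north cap first, then south cap.
    cap_rings = [(int(o), int(n)) for i, (o, n) in enumerate(zip(offsets, lengths))
                 if i < n_cap or i >= n_cap + n_eq]

    @jax.jit
    def ring_fft(x):
        belt = x[eq_start:eq_stop].reshape(n_eq, 4 * nside)
        belt_fft = jnp.fft.fft(belt, axis=-1)                      # one batched FFT
        cap_fft = [jnp.fft.fft(x[o:o + n]) for o, n in cap_rings]  # 2*(nside-1) FFTs
        return belt_fft, cap_fft

    return ring_fft
```

This is only a partial fix. The polar-cap loop still scales with nside, but the result is identical to the per-ring version.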

CosmoMatt commented 1 year ago

> oookkkk I admit defeat... I can't figure out how to do it. I was hoping some clever use of https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.dynamic_slice.html would work, but no, the slice size has to be statically known... CodeGPT/ChatGPT/Copilot were no help at all :-( (there might be some use cases for good old humans after all ^^')
>
> Of course, the nukular option remains... One could implement healpix_fft_jax directly in CUDA and wrap it in JAX. It's only a matter of running a bunch of FFTs, so it's pretty trivial. Here is an example of how it's done: https://github.com/dfm/extending-jax
>
> But that brings two questions:
>
>   • Would you be happy for s2fft to also have some custom compiled ops? The potential problem is that it would require a compilation step, as opposed to a pure JAX library. Plus, long-term maintenance has a higher cost because you have to keep up with the undocumented JAX custom op API.
>   • As long as we are implementing things in CUDA, would it be smarter to directly wrap a full SHT for HEALPix, as opposed to just the ring-wise FFTs?
>
> For now though, I guess living with the for loop version is OK. I think compilation is quite a bit faster now that I'm running on JAX v0.4.

Ah, that's a shame; I was holding out hope you might find a solution! I've already sifted through many of the 'obvious' workarounds, and even a few hacky ones, but it's not at all clear how best to handle this. I suspect we may want to re-engineer the FFT component of the HEALPix transform, which may introduce some error. HEALPix is only approximate regardless, so perhaps this is a fair trade-off in this case.

The reason we have steered away from anything like that so far is that we wanted to ensure we recreated the transforms implemented in existing code, but in JAX. If we make these changes, the transforms may diverge somewhat...
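Not every re-engineering has to trade accuracy. The sketch below shows a loop-free variant whose output matches the per-ring FFTs. It is a swapped-in technique, not the authors' plan, and the helper names are hypothetical.

The idea is to zero-pad every ring to the equatorial length M = 4*nside. Each ring's DFT is then evaluated as a dense matrix product with per-ring twiddle factors exp(-2πi·jk/n_r). All shapes are static, so the whole transform is a single `einsum` regardless of nside. The catch is that the twiddle tensor has (4*nside - 1)·(4*nside)² ≈ 64·nside³ entries. Cost and memory therefore scale far worse than the O(nside² log nside) of ring-wise FFTs, and as written this only makes sense at low resolution.

```python
import numpy as np
import jax
import jax.numpy as jnp

def healpix_ring_lengths(nside):
    # Pixels per ring in RING ordering: 4*i in the caps, 4*nside at the equator.
    cap = [4 * i for i in range(1, nside)]
    return np.array(cap + [4 * nside] * (2 * nside + 1) + cap[::-1])

def make_ring_dft_padded(nside):
    lengths = healpix_ring_lengths(nside)                     # (R,)
    offsets = np.concatenate([[0], np.cumsum(lengths)[:-1]])  # (R,)
    m = int(lengths.max())                                    # M = 4*nside
    j = np.arange(m)
    mask = j[None, :] < lengths[:, None]                      # valid samples per ring
    # Gather indices into the flat map, clipped so padded slots stay in bounds.
    idx = np.minimum(offsets[:, None] + j[None, :], 12 * nside**2 - 1)
    # Exponent k*j mod n_r computed in integers to keep the phase accurate.
    kj_mod = (j[None, :, None] * j[None, None, :]) % lengths[:, None, None]  # (R, M, M)
    twiddle = jnp.asarray(np.exp(-2j * np.pi * kj_mod / lengths[:, None, None]),
                          dtype=jnp.complex64)
    mask_j, idx_j = jnp.asarray(mask), jnp.asarray(idx)

    @jax.jit
    def ring_dft(x):
        rings = jnp.where(mask_j, x[idx_j], 0.0)           # (R, M), zero-padded rings
        out = jnp.einsum("rkj,rj->rk", twiddle, rings)     # one batched dense DFT
        return jnp.where(mask_j, out, 0.0)                 # zero out bins k >= n_r

    return ring_dft
```

Row r of the output holds the n_r DFT coefficients of ring r, followed by zeros.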