astro-informatics / s2fft

Differentiable and accelerated spherical transforms with JAX
https://astro-informatics.github.io/s2fft
MIT License

Stable recursive transforms for high azimuthal band limits #227

Open CosmoMatt opened 1 month ago

CosmoMatt commented 1 month ago

@ElisR picking up our previous discussion here.

Problem

It seems that you would like to efficiently evaluate inverse Wigner transforms for L = N = 64? Are you interested in pushing this to higher L, or is 64 sufficient? Currently the underlying recursion runs into stability issues for azimuthal band limits (N) greater than ~8, depending on L. It's not precisely clear why; since this limit is acceptable for many use cases, we prioritised other projects. In any case, good to pick this up now!

Background

As you suggested, retrofitting a more stable Wigner recursion for high N use-cases is the way to go. I've already implemented a JAX version of the Risbo recursion (code, paper). @jasonmcewen and I have used this recursion in previous work (e.g. paper) and found it to be far more stable as N increases. Combining this recursion with the Fourier decomposition of Wigner D-functions presented by Trapani & Navaza (e.g. paper), we designed efficient FFT-based implementations for the spin-spherical harmonic (paper, equations 9-14) and Wigner transforms (paper, section III).
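For concreteness, the core of such a precompute would iterate the recursion over $\ell$ roughly as follows. This is a minimal sketch: it assumes the JAX Risbo implementation exposes a `compute_full(dl, beta, L, el)` entry point mirroring the NumPy recursion, so the module path and signature should be checked against the codebase.

```python
import jax
import jax.numpy as jnp
from s2fft.recursions import risbo_jax  # assumed module path

jax.config.update("jax_enable_x64", True)  # double precision for stability

L = 64
beta = jnp.pi / 4.0

# dl holds the (2L-1, 2L-1) plane d^el_{mn}(beta); Risbo builds the plane
# for el from the plane for el - 1, so we sweep upwards from el = 0.
dl = jnp.zeros((2 * L - 1, 2 * L - 1), dtype=jnp.float64)
for el in range(L):
    dl = risbo_jax.compute_full(dl, beta, L, el)
    # here each plane (or its Fourier decomposition, equation 8 of the
    # linked paper) would be folded into the precomputed kernel
```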

Solution (Small L <= 64)

For your exact use case, L = N = 64 is low enough that it should be possible to precompute all the necessary Wigner d-functions, which will then be much faster at run-time. This can be done at high N in a stable manner by using the Risbo recursion and following equation 8 of this paper.

This will make the precompute take somewhat longer, but the memory complexity should remain $\mathcal{O}(NL^3)$ and the speed at run-time should be unaffected. We can reduce the memory complexity to $\mathcal{O}(L^3)$, which would be really good for your use case, but this will reduce the speed a little at run-time. In any case, I'll implement both versions so you can choose (a rough sizing of the two variants follows the task list below). What we'll need to do is:

[image: task list]
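For a rough sense of the trade-off between the two variants at L = N = 64, assuming one real double (8 bytes) per precomputed element and ignoring constant factors:

```python
# Rough sizing of the two precompute variants (constants ignored).
L, N = 64, 64
bytes_per_elem = 8  # real float64

mem_full = N * L**3 * bytes_per_elem  # O(N L^3): store everything
mem_reduced = L**3 * bytes_per_elem   # O(L^3): trade memory for run-time

print(f"O(N L^3) variant: {mem_full / 2**20:.0f} MiB")    # ~128 MiB
print(f"O(L^3) variant:   {mem_reduced / 2**20:.0f} MiB") # ~2 MiB
```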

Solution (Large L > 64)

We already have the core recursion necessary to solve the problem. What we'll need to do is:

I'll also need to check that memory/time complexity plays well for gradient propagation.
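As a placeholder for that check, something along the following lines would exercise reverse-mode differentiation through the transform. The entry point `s2fft.wigner.inverse_jax` and the `(2N - 1, L, 2L - 1)` coefficient layout are assumptions to be matched to the actual API.

```python
import jax
import jax.numpy as jnp
import s2fft

jax.config.update("jax_enable_x64", True)

L, N = 64, 64
flmn = jnp.zeros((2 * N - 1, L, 2 * L - 1), dtype=jnp.complex128)

def loss(flmn):
    f = s2fft.wigner.inverse_jax(flmn, L, N)  # hypothetical entry point
    return jnp.sum(jnp.abs(f) ** 2)

# Peak memory during this call is what needs profiling: reverse-mode AD
# stores intermediates from the recursion unless checkpointing is added.
g = jax.grad(loss)(flmn)
```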

CosmoMatt commented 1 month ago

What I'd suggest is that I implement the precompute approach first (starting presently); this should act as a stop-gap whilst I subsequently work on implementing the more scalable recursive approach.

jasonmcewen commented 1 month ago

Thanks for the clear overview @CosmoMatt. For reference, while the Fourier decomposition of Wigner D-functions was used in Trapani & Navaza, it was known well before that.

ElisR commented 1 month ago

Hi @CosmoMatt - thanks for the summary and for pointing me towards some of the literature.

Upon first inspection I think this should be doable for me if I dedicate a few hours to it this week (though apologies if it spills into next week).

Regarding my use case, $L = N = 64$ is probably the minimum I need, but I would also benefit from $L = 128$. I'm not actually differentiating through this transform, so gradients are not critical for me, though I am of course in favour of a solution that works for all. The reason I need these high angular resolutions is because I'm aligning some spherical signals when calculating some training metrics, and $180^{\circ} / 64 \approx 3^{\circ}$ is around the acceptable threshold for my current signals.

Memory Estimate

I am just catching up, so I'll run through a back-of-the-napkin calculation of memory lower bounds for a naive precomputation of all $d_{mn}^{\ell}(\beta)$. Please correct me if I'm missing some symmetries / making a mistake.

Suppose $d_{mn}^{\ell}(\beta)$ is stored as one dense matrix for all $\ell \leq \ell_{\text{max}} = 128 = 2^7$; then we have $128 \times (2 \times 128)^2 = 2^{23}$ matrix elements for each $\beta$. Since $d^{\ell}(\beta)$ is real, we need one real double (i.e. $8 = 2^3$ bytes) for each matrix element. Hence each $\beta$ requires $2^{23} \times 2^3 = 2^{26}$ bytes, i.e. 64 MiB per $\beta$. Adding another factor of $2^7$ if we include all $\beta$, we get $2^{33}$ bytes, i.e. 8 GiB for the entire precomputation.
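The same estimate in a few lines of Python, for anyone checking the arithmetic:

```python
# Back-of-the-napkin memory bound for a dense precompute at ell_max = 128.
L = 128                              # ell_max = 2**7
elems_per_beta = L * (2 * L) ** 2    # 2**23 matrix elements per beta
bytes_per_beta = elems_per_beta * 8  # real float64 -> 2**26 B = 64 MiB
total_bytes = bytes_per_beta * L     # all 2**7 beta samples

print(total_bytes / 2**30, "GiB")    # -> 8.0
```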

For my hardware, this is actually acceptable! I hence support the plan of just swapping out the recursions for Risbo and seeing if things behave nicely, without changing the logic to store $\Delta_{mn}^{\ell}$ etc.

Tests

I still need to look at the codebase a bit more to see what tests need to be modified to fairly assess the $N > 8$ performance, but I'm guessing that I can just follow the pattern of comparing to SSHT. I will also test on the small example code that I posted in #209.
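As a sketch of what such a test might look like, here is a round-trip check at N > 8. The `s2fft.wigner.forward_jax`/`inverse_jax` names, the `(2N - 1, L, 2L - 1)` layout, and the coefficient validity constraints are all assumptions to be reconciled with the existing test utilities.

```python
import numpy as np
import jax.numpy as jnp
import s2fft

def test_wigner_round_trip_high_N(L=64, N=64):
    # Random coefficients, zeroing entries that violate |m| <= el and
    # |n| <= el (layout assumed to be (2N-1, L, 2L-1) with m, n offsets).
    rng = np.random.default_rng(0)
    shape = (2 * N - 1, L, 2 * L - 1)
    flmn = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)
    for el in range(L):
        flmn[:, el, : L - 1 - el] = 0  # m < -el
        flmn[:, el, L + el :] = 0      # m > el
    for n in range(-(N - 1), N):
        flmn[n + N - 1, : abs(n), :] = 0  # el < |n|

    # Hypothetical entry points; swap in the actual transform pair.
    f = s2fft.wigner.inverse_jax(jnp.asarray(flmn), L, N)
    flmn_rec = s2fft.wigner.forward_jax(f, L, N)
    np.testing.assert_allclose(flmn_rec, flmn, atol=1e-10)
```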

CosmoMatt commented 1 month ago

@ElisR no worries; yep, those back-of-the-envelope memory calculations look reasonable. You can see the issue with scalability here, given the $\mathcal{O}(L^4)$ memory scaling!

I actually spent some time tinkering over the weekend and should be able to push an MVP for this switch very shortly. There will of course still be some work to ensure it works for your application.

ElisR commented 1 month ago

Yep, I'm lucky to be one factor of 2 away from silly numbers 😅.

Okay, that's great to hear, sounds good!