astro-informatics / s2fft

Differentiable and accelerated spherical transforms with JAX
https://astro-informatics.github.io/s2fft
MIT License

Degraded accuracy at higher spins? #176

Status: Open (mreineck opened this issue 10 months ago)

mreineck commented 10 months ago

I wrote a little test of the round-trip accuracy of the forward and inverse SHTs and noticed some surprising behaviour: the relative error stays at the 1e-13 to 1e-12 level up to spin 5, then degrades rapidly between spins 5 and 10. Beyond spin 10 the results appear to be more or less random.

For testing I used this script:

```python
from jax import config
config.update("jax_enable_x64", True)  # double precision, so errors near 1e-13 are meaningful
import numpy as np

from s2fft.sampling import s2_samples as samples
from s2fft.transforms import spherical
from s2fft.utils import signal_generator

L = 128
sampling = "mw"

rng = np.random.default_rng(42)

def l2error(a, b):
    # Relative L2 error of b with respect to the reference a
    return np.sqrt(np.sum(np.abs(a-b)**2)/np.sum(np.abs(a)**2))

for spin in range(0, 15):
    # Random complex coefficients for this spin, then an inverse -> forward round trip
    flm = signal_generator.generate_flm(rng=rng, L=L, L_lower=0, spin=spin, reality=False)
    f = spherical.inverse(flm, L, spin, sampling=sampling, method="jax",
                          reality=False)
    flm2 = spherical.forward(f, L, spin, sampling=sampling, method="jax",
                             reality=False)
    print(f"L2 round trip error at spin {spin}: {l2error(flm, flm2)}")
```

The results I get on my laptop (using the CPU implementation) are:

```
L2 round trip error at spin 0: 1.5624566824714117e-13
L2 round trip error at spin 1: 1.5970499209487677e-13
L2 round trip error at spin 2: 1.583213554485799e-13
L2 round trip error at spin 3: 1.5675383128072145e-13
L2 round trip error at spin 4: 1.7105724287507523e-13
L2 round trip error at spin 5: 1.4177072379540323e-12
L2 round trip error at spin 6: 2.475048071913624e-10
L2 round trip error at spin 7: 3.7531618339585e-10
L2 round trip error at spin 8: 5.129932216531392e-07
L2 round trip error at spin 9: 2.978416885023176e-05
L2 round trip error at spin 10: 0.2221563661222978
L2 round trip error at spin 11: 421.1189085144605
L2 round trip error at spin 12: 316133.21353526734
L2 round trip error at spin 13: 170645266.3349521
L2 round trip error at spin 14: 4599977078115.962
```
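To put a number on "degrades rapidly", the short sketch below takes the relative errors reported in this issue (copied verbatim) and converts them into decimal digits lost relative to spin 0, plus the first spin at which a chosen tolerance is exceeded. The helper `first_spin_above` and the thresholds are illustrative choices, not part of s2fft. Note that `l2error(flm, 0)` equals 1, so an error above 1 means the round trip is worse than returning all zeros.

```python
import math

# Relative L2 round-trip errors reported in this issue, indexed by spin (0..14).
ERRORS = [
    1.5624566824714117e-13, 1.5970499209487677e-13, 1.583213554485799e-13,
    1.5675383128072145e-13, 1.7105724287507523e-13, 1.4177072379540323e-12,
    2.475048071913624e-10, 3.7531618339585e-10, 5.129932216531392e-07,
    2.978416885023176e-05, 0.2221563661222978, 421.1189085144605,
    316133.21353526734, 170645266.3349521, 4599977078115.962,
]

def first_spin_above(errors, tol):
    """Return the first spin whose error exceeds tol, or None (hypothetical helper)."""
    for spin, err in enumerate(errors):
        if err > tol:
            return spin
    return None

# Decimal digits lost relative to the spin-0 baseline.
digits_lost = [math.log10(e / ERRORS[0]) for e in ERRORS]

if __name__ == "__main__":
    for spin, d in enumerate(digits_lost):
        print(f"spin {spin:2d}: {d:5.1f} digits lost vs spin 0")
    print("first spin with error > 1e-12:", first_spin_above(ERRORS, 1e-12))
    print("first spin with error > 1 (worse than zeros):", first_spin_above(ERRORS, 1.0))
```

On these numbers, spin 5 already loses about one digit relative to spin 0, spin 10 loses about twelve, and from spin 11 onwards the error exceeds the norm of the input itself.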

Am I using the code incorrectly?

CosmoMatt commented 10 months ago

Hi @mreineck, as far as I can tell you're calling everything correctly. @jasonmcewen and I have noticed similar behaviour, and it's something we plan to look into when we get some time.

My current best guess is that, at larger spin numbers, the recursion (equations 24-27 in our paper) starts to slip past our renormalisation strategy, which could introduce floating-point errors that then compound. However, this is just speculation, and we'll have to dig into it to really know what's going on.
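As a generic illustration of why spin puts pressure on a renormalisation strategy (not a reproduction of equations 24-27 or of s2fft's code), consider the magnitude of the standard closed-form boundary value that Wigner d-function recursions can be seeded from, |d^ℓ_{ℓ,-s}(θ)| = sqrt(C(2ℓ, ℓ-s)) cos^{ℓ-s}(θ/2) sin^{ℓ+s}(θ/2). Each unit of spin shrinks it by more than a factor of tan(θ/2), so near the poles the dynamic range grows quickly with s. The sketch compares a direct double-precision evaluation with one carried in log space, which is one simple renormalisation strategy; s2fft's actual strategy may differ. The degree `ell = 130` and colatitude `theta = 0.01` are hypothetical values chosen to make the effect visible.

```python
import math

# Illustrative only: magnitude of the Wigner seed |d^ell_{ell,-spin}(theta)|
#   = sqrt(C(2*ell, ell - spin)) * cos(theta/2)**(ell - spin) * sin(theta/2)**(ell + spin)
# This is NOT s2fft's recursion; it only shows how spin deepens the dynamic range.

def naive_seed(ell, spin, theta):
    """Evaluate the seed directly in double precision, with no renormalisation."""
    return (math.sqrt(math.comb(2 * ell, ell - spin))
            * math.cos(theta / 2) ** (ell - spin)
            * math.sin(theta / 2) ** (ell + spin))  # this factor can underflow to 0.0

def log_seed(ell, spin, theta):
    """Natural log of the same seed, accumulated in log space so nothing underflows."""
    log_binom = (math.lgamma(2 * ell + 1)
                 - math.lgamma(ell - spin + 1)
                 - math.lgamma(ell + spin + 1))
    return (0.5 * log_binom
            + (ell - spin) * math.log(math.cos(theta / 2))
            + (ell + spin) * math.log(math.sin(theta / 2)))

if __name__ == "__main__":
    ell, theta = 130, 0.01  # hypothetical degree and near-pole colatitude
    for spin in (0, 5, 10, 14):
        reference = math.exp(log_seed(ell, spin, theta))
        print(f"spin {spin:2d}: naive = {naive_seed(ell, spin, theta):.3e}, "
              f"log-space = {reference:.3e}")
```

With these parameters the true seed at spin 14 is about 10^-293, which is representable in double precision, yet the direct product returns 0.0 because sin(θ/2)**144 underflows first; at spin 0 both evaluations agree. This does not show that s2fft fails this way; it only illustrates how higher spin widens the range of magnitudes a recursion has to keep under control.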

For the time being, I've opened a PR warning users that precision may degrade at higher spin (N) values.

In any case, most users are unlikely to encounter spins this high: even for directional convolutions, N ~ 5 is often more than sufficient.

mreineck commented 10 months ago

Thanks a lot, @cosmomatt!

> My current best guess is that the recursion (equations 24-27 in our paper) at larger spin numbers starts to get around our renormalisation strategy, which could introduce floating point errors which compound. However, this is just speculation and we'll have to dig into it to really know what's going on.

I'll try to have a look myself, and will let you know if I find something.

> In any case, most users likely will not encounter spins this high even for directional convolutions N ~ 5 is often more than sufficient.

In most cases I agree, but it's still good to have a warning. During the Planck mission, we often ran the equivalent of a Wigner transform up to spin 20 to simulate the effect of the beam's far sidelobes (see, e.g., https://arxiv.org/abs/2201.03478). For LiteBIRD this limit may be even higher.

CosmoMatt commented 10 months ago

Fantastic, thanks @mreineck I'll try and find some time soon to look into this.