astro-informatics / s2wav

Differentiable and accelerated wavelet transform on the sphere with JAX
https://astro-informatics.github.io/s2wav/
MIT License
12 stars, 0 forks

`s2wav` is slow for filter construction at moderate `L` #77

Closed: paddyroddy closed this issue 9 months ago

paddyroddy commented 9 months ago

Not sure if I'm missing something or if the library is designed only for massive GPUs. I'm running on a top-spec M1 MacBook Pro. This is specifically for s2wav.filter_factory.filters.filters_axisym..., but I've noticed similar behaviour for other functions (and in s2fft).

The following works for me from L=16 to L=2048. It completes, but feels a little slow.

```python
import pys2let
import s2wav
import time

B = 3
J_MIN = 2
ELL_MIN = 4
ELL_MAX = 11

for ell in range(ELL_MIN, ELL_MAX + 1):
    L = 2**ell
    t0 = time.time()
    pys2let.axisym_wav_l(B, L, J_MIN)
    t1 = time.time()
    s2wav.filter_factory.filters.filters_axisym(L, J_MIN, B)
    t2 = time.time()
    s2wav.filter_factory.filters.filters_axisym_vectorised(L, J_MIN, B)
    t3 = time.time()
    s2wav.filter_factory.filters.filters_axisym_jax(L, J_MIN, B)
    t4 = time.time()
    print(
        f"L={L:>4} | pys2let={t1-t0:.0e}s | s2wav={t2-t1:.0e}s | "
        f"s2wav_vec={t3-t2:.0e}s | s2wav_jax={t4-t3:.0e}s",
    )
# L=  16 | pys2let=7e-05s | s2wav=7e-03s | s2wav_vec=7e-03s | s2wav_jax=9e-02s
# L=  32 | pys2let=9e-05s | s2wav=1e-02s | s2wav_vec=1e-02s | s2wav_jax=2e-01s
# L=  64 | pys2let=2e-04s | s2wav=3e-02s | s2wav_vec=3e-02s | s2wav_jax=4e-01s
# L= 128 | pys2let=4e-04s | s2wav=5e-02s | s2wav_vec=5e-02s | s2wav_jax=8e-01s
# L= 256 | pys2let=7e-04s | s2wav=1e-01s | s2wav_vec=1e-01s | s2wav_jax=2e+00s
# L= 512 | pys2let=1e-03s | s2wav=2e-01s | s2wav_vec=2e-01s | s2wav_jax=5e+00s
# L=1024 | pys2let=3e-03s | s2wav=4e-01s | s2wav_vec=4e-01s | s2wav_jax=2e+01s
# L=2048 | pys2let=7e-03s | s2wav=8e-01s | s2wav_vec=8e-01s | s2wav_jax=5e+01s
```

If I ramp up to L=4096, pys2let returns in something like 1e-02 s, whereas s2wav hangs (and at higher L never finishes).

```python
import pys2let
import s2wav
import time

B = 3
J_MIN = 2
L = 4096

t0 = time.time()
pys2let.axisym_wav_l(B, L, J_MIN)
t1 = time.time()
print(f"{t1 - t0:.0e}")
```

jasonmcewen commented 9 months ago

Thanks @paddyroddy. @CosmoMatt, can you comment on this?

CosmoMatt commented 9 months ago

@paddyroddy several things here.

  1. Most important to note: for our purposes, the intention is that the filters are generated once and stored, and then multiple wavelet transforms are evaluated with them. At L = 4096 (which is massive), waiting 1 second is slower than s2let, but it's still just 1 second (hence we haven't worried about optimising anything here, at least not yet).
  2. JAX functions: the JAX filter functions aren't really stress tested yet and are very rough around the edges. In particular, note both this loop in k_lam_jax and the fact that this function is forced to JIT compile on this line. Together these are a recipe for an extremely slow compile, since the loop must be unrolled during tracing, not to mention that memory usage can be large in these cases.
  3. Your test: also keep in mind that, the way your test is written, you are actually measuring the JIT compile time plus one evaluation of the JAX function (see here). Because JAX dispatches work asynchronously, a timer can also stop before the computation has finished unless you block on the result. You'll still run into the compile-time and memory issues mentioned in (2), but if you use block_until_ready() you should at least get a more reasonable estimate of the execution time.

Out of interest, if you comment out the JAX function (i.e. s2wav.filter_factory.filters.filters_axisym_jax(L, J_MIN, B)), does everything run through (albeit slower than the C counterparts)?

paddyroddy commented 9 months ago

Just given that a go, and yes, removing filters_axisym_jax sped it up a lot.

Whilst I understand L=4096 might be considered massive, for the Slepian Wavelets I was regularly dealing with L=16384 (i.e. 128^2), due to the requirement to calculate an L^2 × L^2 matrix to solve for the Slepian functions.
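For scale, here is a rough memory estimate for that matrix. It takes the numbers above at face value: a bandlimit of 128, so 128^2 = 16384 rows and columns. It also assumes dense complex128 (16-byte) entries, which is an assumption about the storage format:

```latex
\begin{aligned}
N &= 128^2 = 16384,\\
N^2 &= 16384^2 = 2^{28} = 268\,435\,456\ \text{entries},\\
N^2 \times 16\ \text{bytes} &= 2^{32}\ \text{bytes} = 4\ \text{GiB}.
\end{aligned}
```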

Out of interest, I had a go at porting to s2wav/s2fft, but I've since given up as it was just too slow at these very high L values.

CosmoMatt commented 9 months ago

Ah OK, our transforms are limited to around L=4096, which is why we weren't too worried about this issue. If you prefer, you could probably compute the filters using s2let and then use them to avoid this problem, @paddyroddy, unless you need to differentiate through the tiling?

paddyroddy commented 9 months ago

No worries. I was merely interested to see how s2wav/s2fft work and whether I could replace pys2let/pyssht, motivated by the possibility of those being deprecated and by offering Windows support for sleplet. But it seems that isn't possible (at least for now), so I'll stick with pys2let/pyssht.

jasonmcewen commented 9 months ago

OK, if I understand correctly, this is just for the construction of the filters, which typically only needs to be performed once. It hasn't yet been optimised, and still takes only a fraction of a second for fairly high ell. There's little point in JIT-compiling here for a single call, since the time will be dominated by compilation. @CosmoMatt, if these functions are jitted, perhaps we should just remove the jit decorator? Otherwise I think this is fine for now, so I'll close this issue.
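Filter construction only needs to happen once per set of parameters, so its cost can be paid once and the result stored for later transforms. The following is a minimal sketch of a disk cache. It assumes the builder returns a sequence of NumPy-compatible arrays. `load_or_build_filters` is a hypothetical helper, not part of s2wav or pys2let.

```python
from pathlib import Path
from typing import Callable, Sequence, Tuple, Union

import numpy as np


def load_or_build_filters(
    path: Union[str, Path], build: Callable[[], Sequence[np.ndarray]]
) -> Tuple[np.ndarray, ...]:
    """Load cached filters from an .npz file, or build and save them.

    `build` is only called when no cache file exists. Encode every
    parameter that affects the filters (L, J_MIN, B, ...) in `path`,
    otherwise a stale cache could be reused for different settings.
    """
    path = Path(path)
    # np.savez appends ".npz" if it is missing, so normalise the name up
    # front to make the existence check look at the same file.
    if path.suffix != ".npz":
        path = path.with_name(path.name + ".npz")

    if path.exists():
        with np.load(path) as data:
            # np.savez stores positional arrays as arr_0, arr_1, ...
            return tuple(data[f"arr_{i}"] for i in range(len(data.files)))

    arrays = tuple(np.asarray(a) for a in build())
    np.savez(path, *arrays)
    return arrays
```

For example: `load_or_build_filters(f"filters_L{L}_J{J_MIN}_B{B}", lambda: s2wav.filter_factory.filters.filters_axisym(L, J_MIN, B))`. This assumes that call returns a sequence of arrays. If it returns a single array, wrap it in a 1-tuple.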