Closed paddyroddy closed 9 months ago
Thanks @paddyroddy. @CosmoMatt can you comment on this?
@paddyroddy several things here.
- For `L = 4096` (which is massive), having to wait 1 second is slower than s2let, but it's still just 1 second (hence we haven't worried about optimising anything here, at least not yet).
- The problem looks to be `k_lam_jax` and the fact that this function is forced to JIT compile on this line. This is a recipe for extremely slow compilation, as the loop must unroll, not to mention that memory use can be large in these cases.
- With `block_until_ready()` you should at least get a more reasonable estimate of the execution time.

Out of interest, if you comment out the JAX function (i.e. `s2wav.filter_factory.filters.filters_axisym_jax(L, J_MIN, B)`), does everything run through (albeit slower than their C counterparts)?
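To illustrate the unrolling point: a minimal sketch, assuming the slow path is a plain Python loop that gets traced by JAX. `unrolled_sum`, `rolled_sum` and `count_equations` are toy functions made up for this example, not s2wav code. It counts the top-level operations in the traced program for a Python loop versus `jax.lax.fori_loop`.

```python
import jax
import jax.numpy as jnp
from jax import lax


def unrolled_sum(x, n_steps):
    # A plain Python loop runs at trace time, so every iteration
    # appends its own operations to the traced program.
    acc = jnp.zeros_like(x)
    for k in range(n_steps):
        acc = acc + jnp.sin(x * k)
    return acc


def rolled_sum(x, n_steps):
    # The same computation expressed as a single structured loop primitive.
    def body(k, acc):
        return acc + jnp.sin(x * k)

    return lax.fori_loop(0, n_steps, body, jnp.zeros_like(x))


def count_equations(fn, x, n_steps):
    # Number of top-level equations in the traced program (jaxpr).
    jaxpr = jax.make_jaxpr(fn, static_argnums=1)(x, n_steps)
    return len(jaxpr.jaxpr.eqns)


x = jnp.linspace(0.0, 1.0, 16)
for n in (8, 64):
    print(n, count_equations(unrolled_sum, x, n), count_equations(rolled_sum, x, n))
```

The unrolled count grows with the number of iterations while the `fori_loop` count stays fixed, which is one reason compile time (and memory) can blow up at large `L`. Whether the s2wav filter construction can actually be rewritten this way is a separate question.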
Just gave that a go, and yes, removing `filters_axisym_jax` sped it up a lot.
Whilst I understand `L=4096` might be considered massive, for the Slepian wavelets I was regularly dealing with `L=16384` (i.e. `128^2`) due to the requirement to calculate an `L^2 * L^2` matrix to solve for the Slepian functions.
I had a go porting to `s2wav`/`s2fft` out of interest, which I've since given up on as it just was too slow at these very high `L` values.
Ah ok, so our end transforms are limited to around `L=4096`, which is why we weren't too worried about this issue. If you would rather, you could probably compute the filters using s2let and then use them to avoid this problem? @paddyroddy unless you need to differentiate through the tiling?
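A rough sketch of that workaround, assuming the filters are available as a NumPy array from somewhere outside JAX. `external_filters` and `apply_filters` are hypothetical stand-ins written for this example (the values are arbitrary bumps, not real wavelet filters, and no s2let call is shown). The array is built once, converted with `jnp.asarray`, and passed to a jitted function as ordinary input.

```python
import numpy as np
import jax
import jax.numpy as jnp


def external_filters(L, n_scales):
    # Hypothetical stand-in for filters produced outside JAX (e.g. by s2let).
    # Returns an array of shape (n_scales, L) of smooth bumps in ell.
    ells = np.arange(L)
    centres = np.linspace(0, L - 1, n_scales)
    width = L / n_scales
    return np.exp(-0.5 * ((ells[None, :] - centres[:, None]) / width) ** 2)


@jax.jit
def apply_filters(spectrum, filters):
    # Weight a per-ell spectrum by each precomputed filter.
    return filters * spectrum[None, :]


L, n_scales = 64, 4
filters = jnp.asarray(external_filters(L, n_scales))  # built once, outside JIT
spectrum = jnp.ones(L)
weighted = apply_filters(spectrum, filters)
print(weighted.shape)
```

Because the filters enter as plain data, nothing about their construction is traced or compiled. The trade-off is the one raised above: gradients cannot flow back through how the filters were built.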
Don't worry. I was merely interested to see how `s2wav`/`s2fft` work and whether I could replace `pys2let`/`pyssht`, motivated by the possibility of the others being deprecated and by offering Windows support for `sleplet`. But it seems like that isn't possible (at least for now), so I'll stick with the others.
Ok, if I understand correctly, this is just for the construction of filters. This typically only needs to be performed once, hasn't yet been optimised and is still just fractions of a second for fairly high ell. There's little point in jitting here for a single call since the time will be taken up by jit compilation. If these functions are jitted @CosmoMatt, perhaps we should just remove the jit decorator? Otherwise I think this is fine for now so I'll close this issue.
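For anyone wanting to check how much of a single call is compilation, here is a minimal timing sketch. `toy_kernel` and `timed_call` are hypothetical helpers for this example, not s2wav functions. It times the first and second call of a jitted function separately and uses `block_until_ready()` so that asynchronous dispatch does not hide the real run time.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def toy_kernel(x):
    # Hypothetical stand-in for an expensive JAX function.
    return jnp.sum(jnp.sin(x) ** 2)


def timed_call(fn, *args):
    # Wall-clock time of one call, waiting for the result to be computed.
    start = time.perf_counter()
    out = fn(*args).block_until_ready()
    return out, time.perf_counter() - start


x = jnp.linspace(0.0, 1.0, 4096)
first_out, first_time = timed_call(toy_kernel, x)    # trace + compile + run
second_out, second_time = timed_call(toy_kernel, x)  # cached executable, run only

print(f"first call:  {first_time:.4f} s")
print(f"second call: {second_time:.4f} s")
```

If the function is only ever called once, the first-call number is the one that matters, which is the argument for dropping the jit decorator on one-off setup code.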
Not sure if I'm missing something or the library is designed only for massive GPUs. I'm running on an M1 MacBook Pro with top specs. This is specifically for `s2wav.filter_factory.filters.filters_axisym...`, but I've noticed similar for other functions (and `s2fft`).

The following works fine for me from `L=16` to `L=2048`. It is fine, but feels a little slow.

Now if we ramp up to `L=4096`, it returns something like `1e-02s`, and `s2wav` hangs (and at higher `L` just never finishes).