astro-informatics / s2fft

Differentiable and accelerated spherical transforms with JAX
https://astro-informatics.github.io/s2fft
MIT License

User report #142

Open EiffL opened 1 year ago

EiffL commented 1 year ago

Dear s2fft devs,

Thanks for this great package! I just wanted to report on some of the pain-points I ran into to hopefully provide constructive feedback.

SHT with inappropriate healpix nside returns mysterious error

Here is what I tried to do:

import numpy as np
import healpy as hp
import s2fft

sampling = "healpix"

nside = 256
lmax = 256

# Sample a Gaussian map with one value per HEALPix pixel
z = np.random.randn(hp.nside2npix(nside))

# Convert to alms
alm = s2fft.forward_jax(z, lmax,
                        reality=True,
                        sampling=sampling,
                        nside=nside)

triggers an error that ends in:

File ~/.local/lib/python3.10/site-packages/jax/_src/lax/lax.py:4672, in _check_shapelike(fun_name, arg_name, obj, non_zero_shape)
   4670 if not all(core.greater_equal_dim(d, lower_bound) for d in obj_arr):
   4671   msg = "{} {} must have every element be {}, got {}."
-> 4672   raise TypeError(msg.format(fun_name, arg_name, bound_error, obj))

TypeError: iota shape must have every element be nonnegative, got (-2,).

jitting time is long

I'm sure this is something you are well aware of, and I guess it's also one of the reasons for the two-stage way of doing the transform. I've found that it can take up to 3 minutes on my laptop to compile forward_jax.

It's no big deal though. There are many ways to accelerate this kind of thing, and I'll be very happy to take a look and see if I can suggest some improvements in #140.

CosmoMatt commented 1 year ago

Hi @EiffL, the first error arises because we expect L >= 2*nside. We originally had an assert for this, but it must have been lost during development, so it's definitely something we should add back with an appropriate error message! (I've added an issue explicitly for this.)
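For illustration, the kind of check described above could look like the following. This is only a sketch: the function name `check_healpix_bandlimit` and the wording of the message are hypothetical and not taken from the s2fft code base; the only thing drawn from this thread is the condition L >= 2*nside.

```python
def check_healpix_bandlimit(L: int, nside: int) -> None:
    """Raise a readable error when the band-limit is too low for HEALPix.

    Hypothetical helper: the name and message are illustrative only.
    The condition L >= 2 * nside is the one stated in this thread.
    """
    if L < 2 * nside:
        raise ValueError(
            f"HEALPix sampling with nside={nside} requires L >= {2 * nside}, "
            f"but got L={L}."
        )


# The values from the report (nside=256, lmax=256) would be rejected up front
# instead of failing deep inside JAX with an iota shape error.
try:
    check_healpix_bandlimit(L=256, nside=256)
except ValueError as err:
    print(err)
```

Raising a ValueError rather than using a bare assert keeps the check active when Python runs with assertions disabled, though either would have caught the case reported here.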

JIT compilation of forward transform

The slow JIT compilation is a more critical issue (which we are aware of) and there are two points to consider:

1) The forward transform involves a few more steps than the inverse, so for all sampling schemes it will take a little longer to compile and execution will be marginally slower. For all but HEALPix sampling this difference should be minimal.

2) We compute SHTs by separation of variables: there is a latitudinal step, which involves Wigner-d recursions, and a longitudinal step, which is 'just' an FFT. For all sampling schemes the latitudinal step is the same, and it is optimised to compile and run efficiently. For HEALPix the issue arises with the FFT portion. As HEALPix samples with rings of different nphi, the lengths of these FFTs are not consistent, so instead of doing a simple jnp.fft.fft( f ) we have to loop over the rings and manually evaluate each FFT. A naive Python loop in JAX is unrolled when JIT compiled, which seriously slows down compilation and can lead to a more memory-hungry compiled graph (though the execution speed shouldn't be affected noticeably). In theory one can compress these loops into lax.fori_loop statements, which are fast to compile. However, this requires that the operations within each ring are identical, which isn't the case, as the number of phi samples (and hence the shape of the FFT) changes from ring to ring.

tl;dr: this should only be an issue for HEALPix sampling, it exists for both the forward and inverse transforms, and it's an issue with the FFT component of the transform.
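The unrolling described above can be seen in a small standalone sketch. It is not the s2fft implementation: `ring_lengths`, `ring_ffts` and `traced_equation_count` are hypothetical helpers written for this illustration, and the ring lengths follow the standard HEALPix layout (polar caps growing in steps of 4, an equatorial belt of 4*nside samples per ring). The sketch only traces the loop with `jax.make_jaxpr` and counts the resulting equations; it does not measure compile time.

```python
import jax
import jax.numpy as jnp


def ring_lengths(nside):
    """Number of phi samples on each HEALPix ring, ordered north to south."""
    ntheta = 4 * nside - 1
    lengths = []
    for t in range(ntheta):
        if t < nside - 1:
            lengths.append(4 * (t + 1))       # north polar cap
        elif t <= 3 * nside - 1:
            lengths.append(4 * nside)         # equatorial belt
        else:
            lengths.append(4 * (ntheta - t))  # south polar cap
    return lengths


def ring_ffts(f, nside):
    """FFT every ring of a flattened HEALPix map with a plain Python loop."""
    out, start = [], 0
    for nphi in ring_lengths(nside):
        # Each ring has its own static shape, so each FFT is traced separately.
        out.append(jnp.fft.fft(f[start:start + nphi]))
        start += nphi
    return jnp.concatenate(out)


def traced_equation_count(nside):
    """Count the equations JAX records when tracing the ring loop."""
    f = jnp.zeros(12 * nside**2, dtype=jnp.complex64)
    jaxpr = jax.make_jaxpr(lambda x: ring_ffts(x, nside))(f)
    return len(jaxpr.jaxpr.eqns)


# The traced program grows with the number of rings (4 * nside - 1).
for nside in (2, 4, 8):
    print(nside, len(ring_lengths(nside)), traced_equation_count(nside))
```

Because every ring contributes its own slice and FFT to the traced program, the graph handed to the compiler grows with nside, which is consistent with the long compile times reported above.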

Solutions

A short-term solution would be to switch to another sampling scheme; none of the remaining schemes should exhibit this compilation issue. A longer-term solution, which we are aware of and investigating in #140, is to optimise this HEALPix FFT loop, perhaps by compressing it into a lax.fori_loop or a scan.

EiffL commented 1 year ago

right right right... I see.... what would happen if one interpolated the healpix map onto one of the other formats that have the same number of points per ring? It's not super pretty.... but would this affect the accuracy a lot?

EiffL commented 1 year ago

of course.... one could also not use the healpix scheme ^^

CosmoMatt commented 1 year ago

If each ring had the same number of samples, then the HEALPix FFT should just reduce to a simple jnp.fft call, which should solve your compilation time issue. What effect this would have on the error I'm not too sure; it's not something we've really looked at yet. There may be other ways around this compile time issue, and I'm happy to look into it when I next get a chance :)
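As a rough sketch of the equal-length case, assume the map is stored as a 2D array of shape (ntheta, nphi) with the same nphi on every ring. The helper names below are hypothetical and this is not s2fft code. It shows only that a single batched FFT along the last axis covers all rings and that the traced program does not grow with the number of rings; it says nothing about the accuracy of interpolating a HEALPix map onto such a grid.

```python
import jax
import jax.numpy as jnp


def uniform_ring_ffts(f):
    """FFT all rings at once; f has shape (ntheta, nphi)."""
    return jnp.fft.fft(f, axis=-1)


def traced_equation_count(ntheta, nphi):
    """Count the equations JAX records when tracing the batched FFT."""
    f = jnp.zeros((ntheta, nphi), dtype=jnp.complex64)
    return len(jax.make_jaxpr(uniform_ring_ffts)(f).jaxpr.eqns)


# The traced program has the same size regardless of how many rings there are.
print(traced_equation_count(8, 16), traced_equation_count(64, 16))
```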

jasonmcewen commented 1 year ago

Thanks for the comments @EiffL. As @CosmoMatt said, the slow JIT compile time for HEALPix is an issue with the HEALPix sampling that we're aware of and need to look into further. It is in fact a general problem with HEALPix, and is also present in CPU HEALPix implementations. In those implementations one would ideally like to plan an FFT to optimise efficiency (e.g. FFTW allows you to plan FFTs to optimise them). However, since the size of the FFT differs from ring to ring, you cannot plan a single efficient FFT. Other codes therefore typically just use an estimated FFT, which is perhaps not optimal, but planning each FFT would be a significant overhead and there is not much of a performance hit in using estimated FFTs. We face a very similar issue here, again due to the varying number of samples on each ring, but it is more acute since it has a significant impact on JIT compile time. We need to look into this some more. This is fairly high priority, but we have some other more pressing commitments for the next couple of weeks at least, so I'm not sure we'll be able to look into this in detail until after then. Any insight that you can provide @EiffL would be very welcome!