astro-informatics / s2fft

Differentiable and accelerated spherical transforms with JAX
https://astro-informatics.github.io/s2fft
MIT License

Computation time spherical harmonics vs Healpy #192

Closed Magwos closed 6 months ago

Magwos commented 7 months ago

Hi !

I've been testing the package and noticed quite long computation times for spherical harmonic transforms compared to Healpy's, with at least a factor of 10 difference, and I am wondering whether this is expected or not.

I am running everything on my Mac M1, after installing version 1.0.2 of s2fft (from GitHub) and PyTorch.

The set-up of my arrays is:

import jax
import numpy as np
import healpy as hp
import s2fft

# HEALPix resolution, number of pixels and band-limit
nside = 64
npix = 12 * nside**2
lmax = 2 * nside

key = jax.random.PRNGKey(0)

map_in = jax.random.normal(key, shape=(npix,))
map_in_np = np.array(map_in)

Then, I have the following times:

flm_HP = hp.map2alm(map_in_np, lmax=lmax, iter=0)

CPU times: user 1.8 ms, sys: 1.45 ms, total: 3.25 ms Wall time: 1.89 ms

sampling = "healpix"
precomps = s2fft.generate_precomputes(lmax+1, 0, sampling, nside, True)

CPU times: user 144 ms, sys: 5.69 ms, total: 149 ms Wall time: 149 ms

flm_check = s2fft.forward_jax(
    map_in,
    lmax+1,
    spin=0,
    nside=nside,
    sampling=sampling,
    reality=True,
    precomps=precomps,
    spmd=False,
    )

First call:

CPU times: user 2.09 s, sys: 90.2 ms, total: 2.18 s Wall time: 2.12 s

Second call:

CPU times: user 46.8 ms, sys: 3.21 ms, total: 50.1 ms Wall time: 40.9 ms
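(A minimal timing sketch for separating JIT compilation from steady-state execution, assuming only the set-up above plus the fact that JAX arrays expose block_until_ready() to wait out the asynchronous dispatch: warm the transform up once, then time repeated calls.)

import timeit

# Warm-up call: pays the compilation cost plus one execution.
flm_warm = s2fft.forward_jax(
    map_in, lmax + 1, spin=0, nside=nside, sampling=sampling,
    reality=True, precomps=precomps, spmd=False,
)
flm_warm.block_until_ready()

# Time post-compilation executions only.
n_repeat = 10
runtime = timeit.timeit(
    lambda: s2fft.forward_jax(
        map_in, lmax + 1, spin=0, nside=nside, sampling=sampling,
        reality=True, precomps=precomps, spmd=False,
    ).block_until_ready(),
    number=n_repeat,
) / n_repeat
print(f"steady-state forward transform: {runtime * 1e3:.1f} ms")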

I have similar times for alm2map:

map_output_HP = hp.alm2map(flm_HP, lmax=lmax, nside=nside)

CPU times: user 1.46 ms, sys: 415 µs, total: 1.88 ms Wall time: 708 µs

precomps_inverse = s2fft.generate_precomputes(lmax+1, 0, sampling, nside, False)

CPU times: user 152 ms, sys: 8.18 ms, total: 160 ms Wall time: 182 ms

map_output_S2FFT = s2fft.inverse(
        flm_check,
        lmax+1,
        spin=0,
        nside=nside,
        sampling=sampling,
        method="jax",
        reality=True,
        precomps=precomps_inverse,
        spmd=False,
    )

First call:

CPU times: user 5.88 s, sys: 169 ms, total: 6.05 s Wall time: 5.97 s

Second call:

CPU times: user 55.5 ms, sys: 3.1 ms, total: 58.6 ms Wall time: 53 ms

Is there something I'm doing wrong, or is it expected on CPU that for such an nside the transforms would take much more time than Healpy, aside from the compilation time? Maybe this is related to issue #142?

I also tried for nside = 512:

CosmoMatt commented 7 months ago

Hi @Magwos, great to see that you've been testing out s2fft! This kind of decrease in performance is expected on CPU as the algorithms which we implement are tailored to be as parallel as possible, so as to really push performance on GPU devices. This is in contrast to existing C methods, which use algorithms that are highly efficient when computing in series.

That being said, certain components of the HEALPix spherical harmonic transform (the longitudinal Fourier transforms) do not support algorithms which can really push GPU performance (due to ragged arrays), so for HEALPix specifically the GPU transforms will be less performant than, say, MW-sampled transforms.

Judging from our benchmarking numbers at L = 2*nside = 128 you should expect something on the order of 200 microseconds, which would be comparable to the healpy numbers you were finding.
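(A small healpy-based illustration of the ragged-array point above, sketched on the assumption that hp.ringinfo returns the per-ring pixel counts: the number of pixels per iso-latitude HEALPix ring varies from 4 near the poles up to 4*nside at the equator, so the longitudinal FFTs cannot be batched into a single dense array.)

import numpy as np
import healpy as hp

nside = 64
rings = np.arange(1, 4 * nside)  # ring indices run from 1 to 4*nside - 1
startpix, ringpix, costheta, sintheta, shifted = hp.ringinfo(nside, rings)
print(ringpix.min(), ringpix.max())  # 4 near the poles, 4*nside (= 256 here) at the equator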

mreineck commented 7 months ago

This kind of decrease in performance is expected on CPU as the algorithms which we implement are tailored to be as parallel as possible, so as to really push performance on GPU devices.

If I read the timings in your paper correctly, your GPU timings are fairly similar to what you'd get using SSHT with the ducc backend on a single CPU core. There's definitely still significant optimization potential in s2fft, but I expect that tuning will be difficult, since you have to rely heavily on JAX's JIT compiler...

CosmoMatt commented 7 months ago

That may depend on the spin number @mreineck; I'm not too familiar with the ducc backend. Totally agree, there is still a lot of optimisation to do for s2fft; the difficulty is finding the time (and the JIT complications, of course!)

mreineck commented 7 months ago

Unfortunately I can't get SSHT to compile at the moment due to issues with Conan, but running my old libsharp benchmark should also give pretty reliable results. I'm comparing against the MW-sampling, L=8192 benchmark in the paper, but I use spin=42 instead of your spin=0 since I know that your algorithm has no special case for spin zero, which would put me at an unfair advantage. I'm also using a Gauss-Legendre grid, which should be fine, as it has practically the same number of rings as MW.

Single thread run:

martin@marvin:~/codes/libsharp$ ./sharp2_testsuite test gauss 8191 -1 -1 16384 42 1

+------------+
| sharp_test |
+------------+

Detected hardware architecture: default
Supported vector length: 4
OpenMP active, but running with 1 thread only.
MPI: not supported by this binary

Testing map analysis accuracy.
spin=42
ntrans=1
lmax: 8191, mmax: 8191
Gauss-Legendre grid, nlat=8192, nlon=16384
Best of 2 runs
wall time for alm2map: 67.809210s
Performance: 32.968722GFLOPs/s
wall time for map2alm: 52.554365s
Performance: 42.538484GFLOPs/s
component 0: rms 8.623141e-13, maxerr 3.095791e-11
component 1: rms 8.613094e-13, maxerr 3.047345e-11

Memory high water mark: 3334.17 MB
Memory overhead: 262.04 MB (7.86% of working set)
(null)       gauss      42 4  1   1   8191   8191   8192  16384 6.78e+01   32.97 5.26e+01   42.54   3334.17   7.86 8.62e-13 3.10e-11

8-thread run (using all physical cores):

martin@marvin:~/codes/libsharp$ ./sharp2_testsuite test gauss 8191 -1 -1 16384 42 1

+------------+
| sharp_test |
+------------+

Detected hardware architecture: default
Supported vector length: 4
OpenMP active: max. 8 threads.
MPI: not supported by this binary

Testing map analysis accuracy.
spin=42
ntrans=1
lmax: 8191, mmax: 8191
Gauss-Legendre grid, nlat=8192, nlon=16384
Best of 2 runs
wall time for alm2map: 10.548920s
Performance: 211.925298GFLOPs/s
wall time for map2alm: 8.659435s
Performance: 258.167301GFLOPs/s
component 0: rms 8.623141e-13, maxerr 3.095791e-11
component 1: rms 8.613094e-13, maxerr 3.047345e-11

Memory high water mark: 3344.55 MB
Memory overhead: 272.43 MB (8.15% of working set)
(null)       gauss      42 4  8   1   8191   8191   8192  16384 1.05e+01  211.93 8.66e+00  258.17   3344.55   8.15 8.62e-13 3.10e-11

So the single-thread run is approximately twice as fast as s2fft on a single GPU (since you are measuring 240 s for a round trip).
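(A quick sanity check of that factor of two, using only the wall times printed in the single-thread run above and the roughly 240 s round-trip figure quoted for s2fft on GPU.)

libsharp_roundtrip = 67.809210 + 52.554365  # single-thread alm2map + map2alm wall times, in seconds
s2fft_roundtrip = 240.0                     # approximate GPU round-trip quoted above, in seconds
print(s2fft_roundtrip / libsharp_roundtrip) # ~2.0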

Magwos commented 7 months ago

Thanks a lot for all the answers!

So if I understand correctly, for the use I intended for these transforms (Gibbs sampling with JAX), in order to use s2fft (in its current state!) I would need to project/deproject between the HEALPix pixelisation and another sampling scheme such as MW? Or perform everything within another pixelisation scheme, to be more efficient and avoid potential projection errors?
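(One way such a projection could look, as a sketch rather than a recommendation: it assumes that the flm returned by s2fft.forward_jax is in s2fft's sampling-independent 2D (ell, m) convention and can therefore be passed directly to s2fft.inverse on an MW grid; the call signatures are the ones already used earlier in this thread.)

L = lmax + 1

# HEALPix map -> band-limited harmonic coefficients (s2fft's 2D (ell, m) layout)
flm = s2fft.forward_jax(
    map_in, L, spin=0, nside=nside, sampling="healpix",
    reality=True, precomps=precomps, spmd=False,
)

# Harmonic coefficients -> signal sampled on the MW grid
map_mw = s2fft.inverse(flm, L, spin=0, sampling="mw", method="jax", reality=True)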

CosmoMatt commented 7 months ago

@Magwos this is going to depend on a few things specific to your particular problem. So that I don't steer you wrong, could you provide a few details:

Depending on these answers, S2FFT may (or may not) be the better-suited package.

CosmoMatt commented 7 months ago

@mreineck very interesting, the ducc backend seems super optimised! With the current setup our mw and mwss sampled algorithms have to be upsampled to twice the number of theta samples which slows them down on a forward pass by a factor of 2. So switching to GL sampling, supported in v1.0.2 released earlier today, will presumably be comparable to ducc.
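(A minimal Gauss-Legendre round-trip sketch, assuming sampling="gl" is accepted by the top-level s2fft.forward and s2fft.inverse wrappers as of v1.0.2, and that s2fft's 2D flm arrays have shape (L, 2L-1); the random coefficients are purely illustrative.)

import jax.numpy as jnp
import numpy as np
import s2fft

L = 128
rng = np.random.default_rng(0)

# Random band-limited coefficients in s2fft's 2D (ell, m) layout; entries with |m| > ell stay zero.
flm = np.zeros((L, 2 * L - 1), dtype=np.complex128)
for ell in range(L):
    for m in range(-ell, ell + 1):
        flm[ell, L - 1 + m] = rng.normal() + 1j * rng.normal()

# Synthesise on the Gauss-Legendre grid, then analyse back to harmonic space.
f_gl = s2fft.inverse(jnp.array(flm), L, spin=0, sampling="gl", method="jax")
flm_rec = s2fft.forward(f_gl, L, spin=0, sampling="gl", method="jax")
print(np.max(np.abs(np.array(flm_rec) - flm)))  # should be at the level of machine precision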

Out of interest, what is the asymptotic complexity you could expect with ducc given an infinite number of threads? For example, the analogue for S2FFT is O(L) scaling with access to sufficiently many GPU devices. Also, does ducc recover linear scaling with the number of threads? It looks like a little less than that from the results you linked.

mreineck commented 7 months ago

With the current setup our mw and mwss sampled algorithms have to be upsampled to twice the number of theta samples which slows them down on a forward pass by a factor of 2.

Interesting, thanks for pointing that out! I had thought that the SSHT implementation avoided this somehow, and implicitly assumed that it was still the case for s2fft ... There is a way to avoid transforming so many rings (though you still need to upsample temporarily), but it involves a lot of FFT magic; if you are interested, see Appendix A of https://arxiv.org/abs/2304.10431, but I'm not sure whether that can help you for s2fft.

So switching to GL sampling, supported in v1.0.2 released earlier today, will presumably be comparable to ducc.

Well, you reach parity between a single Zen 2 core and an A100 :-)

Out of interest, what is the asymptotic complexity you could expect with ducc given an infinite number of threads? For example the analogue for S2FFT is O(L) scaling with access to sufficiently many GPU devices. Also does ducc recover linear scaling with the number of threads? It looks like a little less than that from the results you linked.

Parallelization is done over m, so the work per thread is O(L**2): the total Legendre-transform work is O(L**3), split across the O(L) values of m. (This is for the Legendre transform part; the FFT is parallelized over rings.) Therefore ducc won't scale to a number of threads larger than L, but that feels like an academic limitation :-) If it turns out to be necessary, additional parallelization over rings could be introduced. Scaling is not perfectly linear, mostly because the memory bandwidth of a compute node is a bottleneck; beyond a certain number of threads, the CPUs have to wait in line for memory accesses. It's not really bad, but certainly noticeable.

mreineck commented 7 months ago

For reference: this is (to my knowledge) the current state of the art in GPU SHTs: https://ui.adsabs.harvard.edu/abs/2021EGUGA..2313680S/abstract

No Healpix support though...

jasonmcewen commented 6 months ago

Thanks for the discussion.