Closed Magwos closed 6 months ago
Hi @Magwos, great to see that you've been testing out s2fft! This kind of decrease in performance is expected on CPU: the algorithms we implement are tailored to be as parallel as possible, so as to really push performance on GPU devices. This is in contrast to existing C methods, which use algorithms that are highly efficient when computed serially.
That being said, certain components of the HEALPix spherical harmonic transform (the longitudinal Fourier transforms) do not support algorithms that can really push GPU performance (due to ragged arrays), so for HEALPix specifically the GPU transforms will be less performant than, say, MW sampled transforms.
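To see where the ragged arrays come from, here is a standalone sketch of standard HEALPix ring geometry (not s2fft code): ring lengths grow through the polar caps and plateau in the equatorial belt, so the per-ring Fourier transforms all have different lengths.

```python
# Standard HEALPix ring lengths: northern polar-cap ring i has 4*i pixels,
# equatorial-belt rings have 4*nside, mirrored in the southern cap.
nside = 4
lengths = [4 * min(i, nside, 4 * nside - i) for i in range(1, 4 * nside)]
print(lengths)       # ragged: a different FFT length on each cap ring
print(sum(lengths))  # 12 * nside**2 = 192 pixels in total
```

This raggedness is what prevents batching all rings into one uniform FFT on the GPU.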
Judging from our benchmarking numbers, at L = 2*nside = 128 you should expect something on the order of 200 microseconds, which would be comparable to the healpy numbers you were finding.
> This kind of decrease in performance is expected on CPU as the algorithms which we implement are tailored to be as parallel as possible, so as to really push performance on GPU devices.
If I read the timings in your paper correctly, your GPU timings are fairly similar to what you'd get using SSHT with the ducc backend on a single CPU core. There's definitely still significant optimization potential in s2fft, but I expect that tuning will be difficult, since you have to rely heavily on JAX's JIT compiler...
That may depend on the spin number @mreineck; I'm not too familiar with the ducc backend. Totally agree, there's still a lot of optimisation to do for s2fft; the difficulty is finding time (and the JIT complications, of course!)
Unfortunately I can't get SSHT to compile at the moment due to issues with Conan, but running my old libsharp benchmark should also give pretty reliable results. I'm comparing against the MW-sampling, L=8192 benchmark in the paper, but I use spin=42 instead of your spin=0, since I know that your algorithm has no special case for spin zero, which would otherwise put me at an unfair advantage. I'm also using a Gauss-Legendre grid, which should be fine, as it has practically the same number of rings as MW.
Single-thread run:

```
martin@marvin:~/codes/libsharp$ ./sharp2_testsuite test gauss 8191 -1 -1 16384 42 1
+------------+
| sharp_test |
+------------+
Detected hardware architecture: default
Supported vector length: 4
OpenMP active, but running with 1 thread only.
MPI: not supported by this binary
Testing map analysis accuracy.
spin=42
ntrans=1
lmax: 8191, mmax: 8191
Gauss-Legendre grid, nlat=8192, nlon=16384
Best of 2 runs
wall time for alm2map: 67.809210s
Performance: 32.968722GFLOPs/s
wall time for map2alm: 52.554365s
Performance: 42.538484GFLOPs/s
component 0: rms 8.623141e-13, maxerr 3.095791e-11
component 1: rms 8.613094e-13, maxerr 3.047345e-11
Memory high water mark: 3334.17 MB
Memory overhead: 262.04 MB (7.86% of working set)
(null) gauss 42 4 1 1 8191 8191 8192 16384 6.78e+01 32.97 5.26e+01 42.54 3334.17 7.86 8.62e-13 3.10e-11
```
8-thread run (using all physical cores):

```
martin@marvin:~/codes/libsharp$ ./sharp2_testsuite test gauss 8191 -1 -1 16384 42 1
+------------+
| sharp_test |
+------------+
Detected hardware architecture: default
Supported vector length: 4
OpenMP active: max. 8 threads.
MPI: not supported by this binary
Testing map analysis accuracy.
spin=42
ntrans=1
lmax: 8191, mmax: 8191
Gauss-Legendre grid, nlat=8192, nlon=16384
Best of 2 runs
wall time for alm2map: 10.548920s
Performance: 211.925298GFLOPs/s
wall time for map2alm: 8.659435s
Performance: 258.167301GFLOPs/s
component 0: rms 8.623141e-13, maxerr 3.095791e-11
component 1: rms 8.613094e-13, maxerr 3.047345e-11
Memory high water mark: 3344.55 MB
Memory overhead: 272.43 MB (8.15% of working set)
(null) gauss 42 4 8 1 8191 8191 8192 16384 1.05e+01 211.93 8.66e+00 258.17 3344.55 8.15 8.62e-13 3.10e-11
```
So the single-thread run is approximately twice as fast as s2fft on a single GPU (since you are measuring 240s for a round-trip).
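For concreteness, the factor-of-two claim follows directly from the numbers above (single-thread walltimes from the benchmark, and the 240 s round-trip figure quoted for the GPU):

```python
# Single-thread libsharp walltimes from the benchmark output above
alm2map = 67.809210
map2alm = 52.554365
cpu_round_trip = alm2map + map2alm  # ~120.4 s for analysis + synthesis
gpu_round_trip = 240.0              # s2fft round-trip figure quoted above
print(round(gpu_round_trip / cpu_round_trip, 2))  # ~2.0
```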
Thanks a lot for all the answers!
So if I understand correctly, for my intended use of these transforms (Gibbs sampling with JAX), in order to use s2fft (in its current state!) I would need to project/deproject between the HEALPix pixelisation and another sampling scheme such as MW? Or perform everything within another pixelisation scheme, to be more efficient and avoid potential projection errors?
@Magwos so this is going to depend on a few things specific to your particular problem. So that I don't steer you wrong, could you provide a few details:
Depending on these answers S2FFT may (or may not) be the better suited package.
@mreineck very interesting, the ducc backend seems super optimised! With the current setup our mw and mwss sampled algorithms have to be upsampled to twice the number of theta samples, which slows them down on a forward pass by a factor of 2. So switching to GL sampling, supported in v1.0.2 released earlier today, will presumably be comparable to ducc.
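The factor-of-two cost of theta upsampling can be illustrated with a generic FFT zero-padding upsample (a sketch only; s2fft's internal implementation will differ, and this ignores the Nyquist-bin subtlety for signals with energy at frequency n/2):

```python
import numpy as np

def upsample_theta(f, factor=2):
    """Upsample a periodic signal along axis 0 by zero-padding its FFT."""
    n = f.shape[0]
    F = np.fft.fft(f, axis=0)
    # Place the original spectrum inside a larger zero-padded spectrum,
    # keeping positive frequencies at the front and negative at the back.
    G = np.zeros((factor * n,) + f.shape[1:], dtype=complex)
    half = n // 2
    G[:half] = F[:half]
    G[-(n - half):] = F[half:]
    return np.fft.ifft(G, axis=0) * factor

theta = np.linspace(0, 2 * np.pi, 8, endpoint=False)
g = upsample_theta(np.cos(theta)[:, None])
print(g.shape)  # twice as many theta samples, so roughly twice the work downstream
```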
Out of interest, what is the asymptotic complexity you could expect with ducc given an infinite number of threads? For example, the analogue for S2FFT is O(L) scaling with access to sufficiently many GPU devices. Also, does ducc recover linear scaling with the number of threads? It looks like a little less than that from the results you linked.
> With the current setup our mw and mwss sampled algorithms have to be upsampled to twice the number of theta samples which slows them down on a forward pass by a factor of 2.
Interesting, thanks for pointing that out! I had thought that the SSHT implementation avoided this somehow, and implicitly assumed that was still the case for s2fft...
There is a way to avoid transforming so many rings (though you still need to upsample temporarily), but it involves a lot of FFT magic; if you are interested, see Appendix A of https://arxiv.org/abs/2304.10431, but I'm not sure whether that can help you for s2fft.
> So switching to GL sampling, supported in v1.0.2 released earlier today, will presumably be comparable to ducc.
Well, you reach parity between a single Zen 2 core and an A100 :-)
> Out of interest, what is the asymptotic complexity you could expect with ducc given an infinite number of threads? For example the analogue for S2FFT is O(L) scaling with access to sufficiently many GPU devices. Also does ducc recover linear scaling with the number of threads? It looks like a little less than that from the results you linked.
Parallelization is done over m, so the work per thread is O(L**2). (This is for the Legendre transform part; the FFT is parallelized over rings.) Therefore ducc won't scale to a number of threads larger than L, but that feels like an academic limitation :-) If it turns out to be necessary, additional parallelization over rings could be introduced.
Scaling is not perfectly linear, mostly due to the fact that memory bandwidth of a compute node is a bottleneck; beyond a certain number of threads, the CPUs have to wait in line for memory accesses. It's not really bad but certainly noticeable.
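The effect is quantifiable from the two libsharp runs posted above:

```python
# alm2map walltimes from the single-thread and 8-thread runs above
single_thread = 67.809210
eight_threads = 10.548920
speedup = single_thread / eight_threads
print(f"{speedup:.2f}x speedup on 8 threads "
      f"({100 * speedup / 8:.0f}% parallel efficiency)")
```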
For reference, this is (to my knowledge) the current state of the art in GPU SHTs: https://ui.adsabs.harvard.edu/abs/2021EGUGA..2313680S/abstract. No HEALPix support, though...
Thanks for the discussion.
Hi!
I've been testing the package and noticed quite long computation times for spherical harmonic transforms compared to Healpy's, with at least a factor-10 difference, and I am wondering whether this is expected.
I am running everything on my Mac M1, after installing version 1.0.2 of s2fft (from GitHub) and PyTorch.
The set-up of my arrays is:
Then, I have the following times:
First call:
Second call:
I have similar times for alm2map:
First call:
Second call:
Is there something I'm doing wrong, or is it expected on CPU that for such an nside the transforms would take much more time than Healpy, aside from the compilation time? Maybe related to issue #142?
I also tried for nside = 512 and got:
- forward_jax: 46s and 17s for first and further calls
- inverse_jax: ~6m and 18s for first and further calls
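When timing JIT-compiled functions, it helps to report the first call (which includes compilation) separately from a steady-state figure. A minimal harness, using a NumPy FFT as a hypothetical stand-in for the s2fft call since the actual array set-up above is elided:

```python
import time
import numpy as np

def time_first_and_steady(fn, *args, repeats=3):
    """Return (first-call time, best of `repeats` subsequent calls)."""
    t0 = time.perf_counter()
    fn(*args)
    first = time.perf_counter() - t0
    best = float("inf")
    for _ in range(repeats):
        t0 = time.perf_counter()
        fn(*args)
        best = min(best, time.perf_counter() - t0)
    return first, best

rng = np.random.default_rng(0)
f = rng.standard_normal((512, 512))
first, steady = time_first_and_steady(np.fft.fft2, f)
print(f"first: {first:.4f}s, steady: {steady:.4f}s")
```

With JAX one would additionally call `.block_until_ready()` on the result, so that asynchronous dispatch does not hide the compute time.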