flatironinstitute / finufft

Non-uniform fast Fourier transform library of types 1,2,3 in dimensions 1,2,3

An Idea for performance and memory improvement #293

Open paquiteau opened 1 year ago

paquiteau commented 1 year ago

Hi there, I am doing a PhD on (functional) MRI reconstruction, and I found a possible optimization for the algorithm of (cu)finufft.

This happens in the call to FFTW, in the case of type 1 (nonuniform to uniform points). Instead of computing the full oversampled grid (generally with osf = 2, so 8x more memory in 3D), the internal call to FFTW can be planned to perform a pruned transform (https://www.fftw.org/pruned.html). Basically, the last step of the butterfly has to be done by hand (and only one half is computed). If osf ≠ 2, I don't think it's worth it (the butterfly cannot be split evenly). This would also require the deconvolveshuffle?d functions to be modified.
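To make "the last step of the butterfly done by hand" concrete, here is a minimal 1D NumPy sketch. It is our own illustration, not FFTW's pruning code and not anything in FINUFFT, and the function name `central_modes_pruned` is hypothetical. It assumes an oversampled grid of even length 2M and that only the central `n_modes` outputs are wanted. The two half-length FFTs of the even and odd samples are computed in full. The final radix-2 combination is then evaluated only for the requested frequencies.

```python
import numpy as np


def central_modes_pruned(x: np.ndarray, n_modes: int) -> np.ndarray:
    """DFT outputs k = -(n_modes//2) .. (n_modes-1)//2 of a length-2M signal.

    Decimation in time: X[k] = E[k mod M] + w**k * O[k mod M], where E and O
    are the length-M DFTs of the even- and odd-indexed samples and
    w = exp(-2*pi*i / (2M)). Only this last butterfly stage is pruned: it is
    evaluated for the requested k and nothing else.
    """
    nf = x.shape[0]
    if nf % 2:
        raise ValueError("oversampled length must be even (osf = 2 case)")
    m = nf // 2
    e = np.fft.fft(x[0::2])  # full length-M FFT of the even samples
    o = np.fft.fft(x[1::2])  # full length-M FFT of the odd samples
    k = np.arange(-(n_modes // 2), (n_modes - 1) // 2 + 1) % nf  # wanted outputs
    return e[k % m] + np.exp(-2j * np.pi * k / nf) * o[k % m]


rng = np.random.default_rng(0)
x = rng.standard_normal(16) + 1j * rng.standard_normal(16)
print(np.allclose(central_modes_pruned(x, 8), np.fft.fft(x)[np.arange(-4, 4) % 16]))  # True
```

Only the final stage is pruned. The two size-M FFTs still cover all of the input, and this limits the achievable saving for osf = 2.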

From my understanding this would not apply to type 2 (uniform to non-uniform), because you want the fine frequency grid to do the interpolation.

Unfortunately, I have neither the time nor the expertise to propose an implementation. Still, I think this could help reduce the memory footprint and increase performance. Let me know if this idea sparks anything.

(BTW, Thank you for the great work on (cu)finufft, I could not do my PhD without it)

ahbarnett commented 1 year ago

Dear Pierre-Antoine,

Thank you for the kind words. Please give us a link to your project (if you want) and tell us by what factor it sped up your code (how many iterations do you use in MRI recon?). We need to collect the user success stories :)

As you may see, we are actively unifying cufinufft into finufft, and hope to improve interfaces and add type 3, and maybe upsampfac=1.25 (which would help you).

We have thought about pruned FFTs, but don't believe there is much speed-up available for a factor of 2 per dimension. I think the verdict is that their advantage only kicks in for large factors per dimension. For example, for 3d1 with upsampfac=2, you'd have to sweep through the fine grid 8 times, doing a size-N^3 FFT each time. This won't be faster than a simple (2N)^3 FFT followed by discarding 7/8 of the output. One reason is the repeated sweep of input RAM.
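A simple operation-count model can compare eight size-$N^3$ transforms with one size-$(2N)^3$ transform. The model assumes a size-$M$ FFT costs about $c\,M\log_2 M$. This is an idealization that ignores the memory traffic mentioned above. Under that model the ratio is:

$$
\frac{8\,c\,N^3\log_2\!\left(N^3\right)}{c\,(2N)^3\log_2\!\left((2N)^3\right)}
= \frac{\log_2 N}{\log_2 (2N)}
= 1-\frac{1}{\log_2 (2N)}.
$$

For $N=128$ this is $7/8$, so at most about 12% fewer operations. That figure comes before paying for eight passes over the fine grid and the final combination step.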


-- Alex Barnett, Center for Computational Mathematics, Flatiron Institute | http://users.flatironinstitute.org/~ahb 646-876-5942

paquiteau commented 1 year ago

Hi,

The core of my work relies on mind-inria/mri-nufft, where we provide interfaces for all the NUFFT libraries out there, including (cu)finufft, which is the most stable, actively maintained and fastest. In MRI the NUFFT is wrapped with multi-coil operations (using sensitivity maps) and density compensation[^1]. The coil dimension is essentially the batch dimension (typically N_coils = 32). Functional MRI is even more demanding because multiple volumes are acquired (≈ 400 volumes of typical size 192x192x128, with NU_points ≈ 150k), potentially with a different sampling pattern for each. For compressed-sensing-based reconstruction where temporal information is shared, you need to set up, execute and destroy (when there is not enough memory) that many plans. I am not completely there yet, but if you want more detail about those methods, everything lies in paquiteau/pysap-fmri.
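The following is a quick estimate of the memory for one oversampled grid at the volume size quoted above. It assumes single-precision complex data (8 bytes per point) and upsampfac = 2. It also ignores any rounding of grid sizes to FFT-friendly lengths; 384 and 256 are already even and 5-smooth. The helper `fine_grid_bytes` is hypothetical. The estimate covers one transform only. How many fine grids are alive at once depends on how the batch over coils is executed.

```python
import math


def fine_grid_bytes(modes, upsampfac=2.0, bytes_per_point=8):
    """Bytes of one oversampled grid (no rounding to FFT-friendly sizes)."""
    return math.prod(int(round(upsampfac * n)) for n in modes) * bytes_per_point


modes = (192, 192, 128)
user = math.prod(modes) * 8  # complex64 mode array
fine = fine_grid_bytes(modes)  # 384 x 384 x 256 complex64 fine grid
print(f"modes: {user / 1e6:.1f} MB, fine grid: {fine / 1e6:.1f} MB, ratio {fine / user:.0f}x")
# modes: 37.7 MB, fine grid: 302.0 MB, ratio 8x
```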

Regarding the pruned FFT, I am actually not the only one to have thought about it. People working on BART have proposed a similar idea in an ISMRM abstract (sadly not open access), and a pointer to the code is here. In fact, BART provides its own implementation of the NUFFT; I haven't had the chance to test or benchmark it, however.

I guess it's a matter of how badly the strides would kill the performance benefit. If in the end the wall time remains the same and we use less memory, that's still a win (especially for GPU usage).

I am looking forward to the future development of (cu)finufft and will happily battle-test it.

Pierre-Antoine

[^1]: Actually, there is some trouble estimating the density compensation vector with finufft, but that might deserve its own issue.

mreineck commented 1 year ago

There may be other ways than pruning to speed up the FFT. In a uniform-to-nonuniform 2D NUFFT with, say, an oversampling factor of 2, you only need to transform over the first axis for half of the array, since the other half of the array contains vectors that are identically zero. This saves 25% of the FFT cost without having to use complicated hand-crafted algorithms. In 3D, you only have to transform a quarter of the array along the first axis and half of the array along the second axis, saving an even larger fraction of FFT time.
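The 2D saving can be checked with a small NumPy sketch. It assumes the n×n modes sit in a 2n×2n grid with negative frequencies wrapped to the end of each axis. That is a standard FFT layout, and we have not checked it against FINUFFT's internal layout. The sketch uses fancy indexing where real code would use batched plans. `wrapped_mode_indices` and `type2_fft_skip_zeros` are hypothetical names.

```python
import numpy as np


def wrapped_mode_indices(n_modes: int, nf: int) -> np.ndarray:
    """Fine-grid indices of modes -(n//2) .. (n-1)//2, negatives wrapped to the end."""
    return np.arange(-(n_modes // 2), (n_modes - 1) // 2 + 1) % nf


def type2_fft_skip_zeros(grid: np.ndarray, n1: int) -> np.ndarray:
    """fft2 of a zero-padded grid whose nonzeros lie in n1 wrapped columns.

    Axis 0 is transformed only for the columns that can be nonzero (the others
    are identically zero and stay zero); axis 1 is then transformed for every
    row, since after the first pass all rows are dense.
    """
    out = grid.astype(np.complex128, copy=True)
    cols = wrapped_mode_indices(n1, out.shape[1])
    out[:, cols] = np.fft.fft(out[:, cols], axis=0)  # n1 of nf1 column FFTs
    return np.fft.fft(out, axis=1)                    # all nf0 row FFTs


# Usage: 8x8 modes in a 16x16 grid (oversampling factor 2).
rng = np.random.default_rng(0)
n, nf = 8, 16
idx = wrapped_mode_indices(n, nf)
grid = np.zeros((nf, nf), dtype=np.complex128)
grid[np.ix_(idx, idx)] = rng.standard_normal((n, n)) + 1j * rng.standard_normal((n, n))
print(np.allclose(type2_fft_skip_zeros(grid, n), np.fft.fft2(grid)))  # True
```

With nf = 2n, the sketch performs n column transforms and 2n row transforms instead of 2n + 2n. That is 3/4 of the 1D transforms, which matches the 25% saving.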

In the nonuniform-to-uniform direction things are not as obvious, but the same amount of time can be saved, since you don't need accurate results on the entire oversampled grid, just on part of it.
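For the nonuniform-to-uniform direction, the skipping moves to the output side. A NumPy sketch under the same wrapped-mode assumption follows; its names are hypothetical. The spread grid is dense, so the first pass runs over every row. The second pass can be restricted to the columns that contain wanted modes, and the deconvolution step is not shown.

```python
import numpy as np


def wrapped_mode_indices(n_modes: int, nf: int) -> np.ndarray:
    """Fine-grid indices of modes -(n//2) .. (n-1)//2, negatives wrapped to the end."""
    return np.arange(-(n_modes // 2), (n_modes - 1) // 2 + 1) % nf


def type1_fft_needed_only(grid: np.ndarray, n_modes: tuple) -> np.ndarray:
    """Return fft2(grid) restricted to the n0 x n1 wanted (wrapped) modes.

    The spread grid is dense, so every row is transformed along axis 1.
    Only the n1 columns holding wanted modes are then transformed along
    axis 0, and only n0 outputs of each are kept.
    """
    n0, n1 = n_modes
    nf0, nf1 = grid.shape
    rows = wrapped_mode_indices(n0, nf0)
    cols = wrapped_mode_indices(n1, nf1)
    tmp = np.fft.fft(grid, axis=1)[:, cols]  # nf0 row FFTs, keep n1 columns
    tmp = np.fft.fft(tmp, axis=0)            # n1 column FFTs instead of nf1
    return tmp[rows, :]


rng = np.random.default_rng(0)
grid = rng.standard_normal((12, 16)) + 1j * rng.standard_normal((12, 16))
modes = type1_fft_needed_only(grid, (6, 8))
print(modes.shape)  # (6, 8)
```

Transforming along axis 1 first matters here. If axis 0 were transformed first, every column would still be needed. Reversing the order lets the second pass drop the unwanted half.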

I'm using this approach pretty successfully in my own code.

ahbarnett commented 1 year ago

Dear Pierre-Antoine and Martin,

Thank you for the details; that's exciting. Re BART, I already tested it (the CPU code at least). You can read about my adventures getting it working, and the resulting performance (the * symbols in Figs. 6.3 and 6.4), in https://arxiv.org/abs/1808.06736. I.e., we beat them by at least 10x at the same accuracy :)

Martin, that's a great idea; I had not thought of dissecting the 2d and 3d FFTW calls like this. For upsampfac=2 it could be very useful (not worth it for upsampfac=1.25), and you'd write a custom 2d and 3d FFT call composed of many 1d transforms, just the subsets needed. (I'm assuming that wouldn't be slower than what FFTW does for 2d or 3d... I haven't looked at their code.) It is good to have your eyes on our project :)

Cheers, Alex


mreineck commented 1 year ago

> and you'd write a custom 2d and 3d FFT call composed of many 1d transforms, just the subsets needed. (I'm assuming that wouldn't be slower than what FFTW does for 2d or 3d... I haven't looked at their code).

For 2D it would probably be a set of three plans generated with fftw_plan_many_dft (two for the first axis, one for the second). Except for the planning overhead (and yes, this may be a problem), this should not cause any unexpected slowdowns. For 3D, it's seven plans.
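The count of seven plans in 3D can be reproduced with a NumPy sketch of the uniform-to-nonuniform case. It assumes wrapped modes, so the nonzero indices along each axis form two contiguous blocks at the start and end of the fine grid. Axis 0 then needs 2×2 block transforms, axis 1 needs 2, and axis 2 needs 1. Each `np.fft.fft` call stands in for one batched plan such as `fftw_plan_many_dft` would create. The helpers `mode_blocks` and `type2_fft3_blocks` are hypothetical.

```python
import numpy as np


def mode_blocks(n: int, nf: int) -> list:
    """Contiguous fine-grid slices holding modes 0..(n-1)//2 and -(n//2)..-1."""
    pos, neg = (n - 1) // 2 + 1, n // 2
    blocks = [slice(0, pos), slice(nf - neg, nf)]
    return [b for b in blocks if b.stop > b.start]  # drop an empty block (n == 1)


def type2_fft3_blocks(grid: np.ndarray, n_modes: tuple):
    """fftn of a zero-padded 3D mode grid via batched block transforms.

    Each np.fft.fft call stands in for one batched plan; the number of
    calls is returned alongside the result.
    """
    out = grid.astype(np.complex128, copy=True)
    b1 = mode_blocks(n_modes[1], out.shape[1])
    b2 = mode_blocks(n_modes[2], out.shape[2])
    calls = 0
    for s1 in b1:  # axis 0: only lines whose axis-1 and axis-2 indices hold modes
        for s2 in b2:
            out[:, s1, s2] = np.fft.fft(out[:, s1, s2], axis=0)
            calls += 1
    for s2 in b2:  # axis 1: axis 0 is dense now, axis 2 is still sparse
        out[:, :, s2] = np.fft.fft(out[:, :, s2], axis=1)
        calls += 1
    out = np.fft.fft(out, axis=2)  # axis 2: everything is dense
    calls += 1
    return out, calls
```

When at least two modes are requested along axes 1 and 2, this makes 4 + 2 + 1 = 7 calls. The 2D analogue is 2 + 1 = 3.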

paquiteau commented 1 year ago

To foster more discussion, here is the abstract I was referring to: https://perso.crans.org/comby/ISMRM2023/ISMRM%202023.html (the abstracts are only viewable correctly as HTML pages, so I scraped it). The BART NUFFT is indeed slower, but more memory-efficient, thanks to their trick. Note that the abstract does not compare BART and cufinufft, probably to avoid some embarrassment.

The abstract does not refer to the stride problem you mention; the only one I can see happens on the kernel (which is much smaller) rather than on the data. I think this trick is complementary to the one of @mreineck, but I wouldn't consider myself a black belt of FFT-jitsu, so I will be happy to have more insight on this.

mreineck commented 1 year ago

I agree that the two optimizations are complementary, so they can be combined.

There is one more aspect that should be considered when doing large FFTs: if the strides of some array dimensions are critical (i.e. they are a multiple of 4096 bytes on most current CPUs), cache re-use will be extremely bad, and the FFT will be very slow. Sadly, this situation mostly turns up for FFT sizes that are considered optimal, i.e. large powers of 2. Some trickery with array strides can work around the problem, and it can give huge performance boosts.

Here is an example that needs to be run in ipython (because of the %timeit magic). Please note that this problem exists with all FFT implementations; I just chose ducc here, because this allows for a very short example code.

import numpy as np
import ducc0

shape=(4096,4096)

# unpadded test
a=np.zeros(shape, dtype=np.complex128)
print(a.shape)
print(a.strides)
%timeit ducc0.fft.c2c(a, axes=(0,), out=a)

# padded test, avoiding critical strides
a=ducc0.misc.make_noncritical(a)
print(a.shape)
print(a.strides)
a[()] = 0
%timeit ducc0.fft.c2c(a, axes=(0,), out=a)

Making arrays non-critical in finufft is not trivial, since multi-D data are assumed to be stored in compact form, so I have not tried to change this yet. When working with power-of-2 grids and oversampling factors of 2, it should give substantial speedups however.
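Whether an array has critical strides can be checked directly from NumPy's `strides`. Over-allocating the trailing axes by a few elements is the same padding idea described above, and it leaves the logical shape unchanged. This sketch assumes a 4096-byte critical stride, and `critical_axes` and `padded_zeros` are hypothetical helpers. Padding every non-leading axis by 3 elements is an arbitrary choice that does not guarantee noncritical strides for every shape. The example uses a smaller grid than the ducc0 test to keep allocations modest.

```python
import numpy as np


def critical_axes(a: np.ndarray, page: int = 4096) -> list:
    """Axes whose byte stride is a nonzero multiple of `page`."""
    return [ax for ax, s in enumerate(a.strides) if s != 0 and abs(s) % page == 0]


def padded_zeros(shape: tuple, pad: int = 3, dtype=np.complex128) -> np.ndarray:
    """Zero view of `shape` whose non-leading axes are over-allocated by `pad`."""
    alloc = (shape[0],) + tuple(n + pad for n in shape[1:])
    buf = np.zeros(alloc, dtype=dtype)
    return buf[tuple(slice(0, n) for n in shape)]  # same shape, larger strides


a = np.zeros((1024, 1024), dtype=np.complex128)
b = padded_zeros((1024, 1024))
print(a.strides, critical_axes(a))  # (16384, 16) [0]
print(b.strides, critical_axes(b))  # (16432, 16) []
```

The padded view is no longer contiguous. That is exactly why this is awkward for code that assumes compact multi-D storage.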

ahbarnett commented 1 year ago

Hi Martin, fascinating. I had noticed that 2^n sizes were sometimes slower than 5-smooth neighbors with FFTW. On my laptop, with Python 3.9 and conda-installed ducc0, your example gives:

(4096, 4096)
(65536, 16)
287 ms ± 573 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
(4096, 4096)
(65584, 16)
173 ms ± 173 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

So, 65% faster.

In FINUFFT I disagree: one has a lot of freedom in choosing the FFT size, because the user data is copied in/out of the central part of a new array. Changing the routine next235even is all that would be required. Is there some recipe from make_noncritical we can borrow?


mreineck commented 1 year ago

Ah, of course, you have the choice of just increasing the array dimensions a little bit; I forgot about that! make_noncritical does not change the actual array dimensions. It just adds minimal padding where necessary, so as a method for obtaining actual FFT lengths it would be quite bad. You can update next235even so that it takes an additional argument: the unit stride in bytes along the direction for which you want the length. Then you just ignore all candidate results where length*stride/4096 is an integer. You have to start from the last dimension (in C ordering) and work your way forwards. Overall this approach might be slightly worse than the padding one, since you avoid large power-of-2 lengths, which are in principle really advantageous.
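Here is a Python sketch of the suggested length selection; all names in it are hypothetical. `next235even` is a re-implementation assuming it returns the smallest even 2,3,5-smooth integer not below n; it is not FINUFFT's C++ code. `next235even_noncritical` skips any length whose product with the unit stride is a multiple of 4096 bytes. Dimensions are chosen from the last (C order) forwards. The first dimension is left unconstrained, because its product is the total array size, which is not the stride of any axis. Every candidate length is even, so when twice the unit stride is already a multiple of 4096 no valid length exists. That case raises an error instead of looping forever.

```python
def is_235_smooth(m: int) -> bool:
    """True if m has no prime factors larger than 5."""
    for p in (2, 3, 5):
        while m % p == 0:
            m //= p
    return m == 1


def next235even(n: int) -> int:
    """Smallest even 2,3,5-smooth integer >= n (re-implementation, assumed semantics)."""
    m = max(2, n + n % 2)
    while not is_235_smooth(m):
        m += 2
    return m


def next235even_noncritical(n: int, unit_stride_bytes: int, page: int = 4096) -> int:
    """Like next235even, but skip lengths where length * unit_stride is a multiple of page."""
    if (2 * unit_stride_bytes) % page == 0:
        raise ValueError("every even length gives a critical stride")
    m = next235even(n)
    while (m * unit_stride_bytes) % page == 0:
        m = next235even(m + 2)
    return m


def choose_fine_grid(min_sizes, itemsize: int = 16, page: int = 4096) -> tuple:
    """Pick fine-grid sizes from the last (C-order) dimension forwards."""
    sizes = [0] * len(min_sizes)
    stride = itemsize  # unit stride in bytes along the current dimension
    for d in reversed(range(len(min_sizes))):
        if d == 0:
            sizes[d] = next235even(min_sizes[d])  # no axis has stride len0 * stride0
        else:
            sizes[d] = next235even_noncritical(min_sizes[d], stride, page)
        stride *= sizes[d]
    return tuple(sizes)


print(choose_fine_grid((4096, 4096)))  # (4096, 4320)
```

For a 4096×4096 complex128 grid, this keeps 4096 for the first axis and moves the last axis to 4320, the next even 5-smooth length. That shows the caveat above: avoiding large powers of 2 costs some extra grid size.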