ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

[Feature] Support fft based convolution #811

Open adonath opened 4 months ago

adonath commented 4 months ago

It would be nice to have FFT-based convolution supported in MLX. FFT-based convolution performs much better for large images/arrays and kernels. The FFT building blocks are already supported in MLX, so it is mostly a matter of combining them into a convolution operation.
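For reference, the identity being combined here is the convolution theorem. Take $f$ as the $H \times W$ image and $g$ as the $k_h \times k_w$ kernel. A tilde marks zero-padding to the full output size, which keeps the FFT's circular convolution from wrapping around:

```latex
f * g \;=\; \mathcal{F}^{-1}\!\left(\mathcal{F}(\tilde f)\,\cdot\,\mathcal{F}(\tilde g)\right),
\qquad \tilde f,\ \tilde g \in \mathbb{R}^{(H + k_h - 1)\times(W + k_w - 1)}
```

Direct convolution costs roughly $H W k_h k_w$ multiply-adds. The FFT route costs $O(N \log N)$ with $N = (H + k_h - 1)(W + k_w - 1)$, which is why it pulls ahead as the kernel grows.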

sebblanchet commented 4 months ago

@awni I wouldn't mind looking into implementing this. What do you think?

awni commented 4 months ago

One challenge here is that FFT is not yet supported on the GPU (in Metal). So you could use it, but on the CPU it would almost certainly be much slower than our GPU convolution.

Also I think FFT-based convolution is more of an implementation detail. If there are some sizes that are slow for you, please share any benchmarks. We can then figure out the best way to make them faster (which may or may not require an FFT-based convolution).

adonath commented 4 months ago

Thanks @awni and @sebblanchet! I did a quick implementation of an FFT-based convolution in MLX:

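```python
import mlx.core as mx


def _centered(arr, newshape):
    """Return the center of `arr` cropped to `newshape` (SciPy-style "same" output)."""
    # Shapes are plain Python tuples, so the index arithmetic can stay on the host
    # instead of round-tripping through mx.array and .item().
    startind = [(curr - new) // 2 for curr, new in zip(arr.shape, newshape)]
    myslice = [slice(start, start + new) for start, new in zip(startind, newshape)]
    return arr[tuple(myslice)]


def convolve_fft(image, kernel, stream):
    """Convolve via FFT for MLX arrays whose last two axes are spatial, e.g. (1, 1, H, W)."""
    image_2d, kernel_2d = image[0, 0], kernel[0, 0]

    # Full linear-convolution size per spatial axis (H + kh - 1, W + kw - 1); padding to
    # this size keeps the FFT's circular convolution from wrapping around.
    shape = [image_2d.shape[i] + kernel_2d.shape[i] - 1 for i in range(image_2d.ndim)]

    # rfft2/irfft2 act on the last two axes and zero-pad them to `shape`.
    image_ft = mx.fft.rfft2(image, s=shape, stream=stream)
    kernel_ft = mx.fft.rfft2(kernel, s=shape, stream=stream)
    result = mx.fft.irfft2(image_ft * kernel_ft, s=shape, stream=stream)
    return _centered(result, image.shape)
```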
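One way to sanity-check a helper like this is to compare it with a direct convolution on a small input. The sketch below uses NumPy as a stand-in for `mx.fft`, since `np.fft.rfft2`/`np.fft.irfft2` take the same `s` argument, so it runs on any machine. `direct_convolve_same` and `fft_convolve_same` are hypothetical helpers written only for this check, and the crop mirrors `_centered`.

```python
import numpy as np


def direct_convolve_same(image, kernel):
    """Direct 2D linear convolution, cropped to the image size ("same" mode)."""
    H, W = image.shape
    kh, kw = kernel.shape
    full = np.zeros((H + kh - 1, W + kw - 1))
    # Accumulate shifted, weighted copies of the image: full[i+a, j+b] += image[i, j] * kernel[a, b].
    for a in range(kh):
        for b in range(kw):
            full[a:a + H, b:b + W] += kernel[a, b] * image
    top, left = (kh - 1) // 2, (kw - 1) // 2
    return full[top:top + H, left:left + W]


def fft_convolve_same(image, kernel):
    """Same computation via the convolution theorem, mirroring convolve_fft."""
    shape = [image.shape[i] + kernel.shape[i] - 1 for i in range(2)]
    product = np.fft.rfft2(image, s=shape) * np.fft.rfft2(kernel, s=shape)
    result = np.fft.irfft2(product, s=shape)
    # Central crop, as in _centered: start = (full_size - image_size) // 2.
    top, left = [(result.shape[i] - image.shape[i]) // 2 for i in range(2)]
    return result[top:top + image.shape[0], left:left + image.shape[1]]


rng = np.random.default_rng(0)
image = rng.standard_normal((64, 64))
kernel = rng.standard_normal((7, 5))
print(np.allclose(direct_convolve_same(image, kernel), fft_convolve_same(image, kernel)))
```

The direct version is only practical for tiny inputs, which is exactly the point of the FFT path.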

I also did a simple benchmark. It uses a random 1024x1024 image and varying kernel sizes, and compares mx.conv2d on the GPU and on the CPU, the FFT-based algorithm above, and, for reference, SciPy's FFT convolution implementation. The result is the following:

[Figure: mlx-conv-mini-benchmark]

I think the results are exactly as expected.

In general, I think it is still worth having an FFT-based convolution. For NNs with small kernels there is no point, but many scientific applications rely on large kernels (think of cross-correlations, convolution with pathological point spread functions, etc.).

I think it is worth re-opening.
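Cross-correlation, one of the large-kernel workloads mentioned above, maps onto the same machinery: correlating with a template equals convolving with the template flipped along both axes. Below is a small template-matching sketch, again with NumPy standing in for `mx.fft`. `xcorr_fft` is a hypothetical helper and the data is random.

```python
import numpy as np


def xcorr_fft(image, template):
    """Full 2D cross-correlation via FFT (correlation = convolution with a flipped template)."""
    shape = [image.shape[i] + template.shape[i] - 1 for i in range(2)]
    flipped = template[::-1, ::-1]
    product = np.fft.rfft2(image, s=shape) * np.fft.rfft2(flipped, s=shape)
    return np.fft.irfft2(product, s=shape)


rng = np.random.default_rng(1)
image = rng.standard_normal((128, 128))
template = image[40:56, 70:86].copy()  # a 16x16 patch whose position we pretend not to know

corr = xcorr_fft(image, template)
peak = np.unravel_index(np.argmax(corr), corr.shape)
# In "full" output, a patch at (row, col) peaks at (row + th - 1, col + tw - 1).
row = int(peak[0]) - (template.shape[0] - 1)
col = int(peak[1]) - (template.shape[1] - 1)
print(row, col)
```

The raw (unnormalized) correlation works here because the image is white noise. Real data usually needs normalized cross-correlation.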

awni commented 4 months ago

Ok sounds good! Thanks for the benchmarks, that's really interesting!

awni commented 4 months ago

One option is to update the CPU convolution to dispatch to an FFT implementation when the input sizes make sense. We would want to benchmark it in a few settings to be sure it's a strict improvement.
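A CPU-side dispatch could start as a comparison of rough operation counts. The sketch below is a hypothetical heuristic, not MLX code. Its cost model ignores memory traffic and real constants, so `fft_cost_factor` stands in for a value that would have to be calibrated with benchmarks on the target machine.

```python
import math


def pick_conv_method(image_shape, kernel_shape, fft_cost_factor=5.0):
    """Guess whether direct or FFT convolution is cheaper for a 2D input.

    Direct convolution does roughly H * W * kh * kw multiply-adds. An FFT
    convolution needs three transforms (image, kernel, inverse), each on the
    order of N * log2(N) with N = (H + kh - 1) * (W + kw - 1).
    `fft_cost_factor` is a hypothetical, uncalibrated constant.
    """
    H, W = image_shape
    kh, kw = kernel_shape
    direct_cost = H * W * kh * kw
    n = (H + kh - 1) * (W + kw - 1)
    fft_cost = fft_cost_factor * 3 * n * math.log2(n)
    return "fft" if fft_cost < direct_cost else "direct"


print(pick_conv_method((1024, 1024), (3, 3)))
print(pick_conv_method((1024, 1024), (101, 101)))
```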

adonath commented 4 months ago

Thanks for re-opening @awni!

> One option is to update the CPU convolution to dispatch to an FFT implementation when the input sizes make sense. We would want to benchmark it in a few settings to be sure it's a strict improvement.

This is what SciPy does too, see https://github.com/scipy/scipy/blob/v1.12.0/scipy/signal/_signaltools.py#L1161. There is the option to either measure or actually compute the flops. Measuring only makes sense for repeated convolutions, but it probably gives the most accurate results across arbitrary architectures. Looking at the SciPy code, computing the flops seems like it may be too complex. Or is there a general way to predict flops for MLX operations? (That would be nice to have...)
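The "measure" option pays off when the same shapes recur, because the timing cost is paid once per shape pair and then amortized. Below is a minimal sketch of that idea, loosely analogous to what SciPy's `choose_conv_method(..., measure=True)` does for a single call. `make_measuring_dispatcher` is a hypothetical helper. With MLX, the two callables would also need to force evaluation (e.g. via `mx.eval`) for the timings to mean anything.

```python
import time

import numpy as np


def make_measuring_dispatcher(direct_fn, fft_fn, repeats=3):
    """Return conv(image, kernel) that times both methods once per shape pair
    and reuses the faster one for later calls with the same shapes."""

    def best_time(fn, image, kernel):
        best = float("inf")
        for _ in range(repeats):
            tic = time.perf_counter()
            fn(image, kernel)
            best = min(best, time.perf_counter() - tic)
        return best

    choice = {}  # (image shape, kernel shape) -> chosen function

    def conv(image, kernel):
        key = (tuple(image.shape), tuple(kernel.shape))
        if key not in choice:
            t_direct = best_time(direct_fn, image, kernel)
            t_fft = best_time(fft_fn, image, kernel)
            choice[key] = fft_fn if t_fft < t_direct else direct_fn
        return choice[key](image, kernel)

    conv.choice = choice  # exposed for inspection
    return conv


# Toy usage with stand-in functions (not real convolutions): the "direct" one sleeps, so "fft" wins.
def slow_direct(image, kernel):
    time.sleep(0.002)
    return "direct"


def quick_fft(image, kernel):
    return "fft"


conv = make_measuring_dispatcher(slow_direct, quick_fft)
print(conv(np.zeros((32, 32)), np.zeros((5, 5))))
```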

In general, the performance of MLX operations is probably much more predictable, given the more homogeneous M-series architectures. So there could be a third option: parametrize the scaling laws based on empirical benchmarks, or something similar...
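The empirical option could work like this: time each method over a range of sizes on a given chip, fit a simple scaling law per method, and compare the fitted curves at dispatch time. The sketch below covers the fitting step only. It assumes a single-variable power law `t ≈ a * n**b`, though a real model for convolution would likely need both image and kernel size as inputs. `fit_power_law` is hypothetical, and the timings below are synthetic.

```python
import math


def fit_power_law(sizes, times):
    """Least-squares fit of t ≈ a * n**b in log-log space; returns (a, b)."""
    xs = [math.log(n) for n in sizes]
    ys = [math.log(t) for t in times]
    x_mean = sum(xs) / len(xs)
    y_mean = sum(ys) / len(ys)
    # Slope of the log-log regression line is the exponent b.
    b = sum((x - x_mean) * (y - y_mean) for x, y in zip(xs, ys)) / sum(
        (x - x_mean) ** 2 for x in xs
    )
    a = math.exp(y_mean - b * x_mean)
    return a, b


# Synthetic timings generated from t = 2e-9 * n**1.5, only to show the fit recovers (a, b).
sizes = [2**k for k in range(10, 21)]
times = [2e-9 * n**1.5 for n in sizes]
a, b = fit_power_law(sizes, times)
print(f"a={a:.3g}, b={b:.3f}")
```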

adonath commented 4 months ago

Here is the gist with the code for the benchmark: https://gist.github.com/adonath/3f16b30498c60f25cf1349792c15283c
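When reproducing numbers like these, one MLX-specific detail matters: MLX evaluates lazily, so a timing loop must force evaluation (e.g. with `mx.eval`) or it only measures graph construction. Below is a minimal, library-agnostic harness sketch. `time_op` is a hypothetical helper, not part of MLX or of the linked gist.

```python
import time
from statistics import median


def time_op(fn, *args, sync=lambda out: out, warmup=1, repeats=5):
    """Return the median wall-clock time (seconds) of fn(*args).

    `sync` forces the result to be computed before the clock stops. MLX is lazy,
    so for MLX ops pass sync=mx.eval; for eager libraries the default no-op is fine.
    """
    for _ in range(warmup):
        sync(fn(*args))  # warm-up: first-call overheads, caches
    times = []
    for _ in range(repeats):
        tic = time.perf_counter()
        sync(fn(*args))
        times.append(time.perf_counter() - tic)
    return median(times)
```

With MLX this might be called as `time_op(convolve_fft, image, kernel, mx.cpu, sync=mx.eval)`. Using the median rather than the mean keeps one slow outlier run from skewing the comparison.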