ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

[Performance] PyTorch (MPS) is faster than MLX in backward of convolution layer #1313

Closed arnold-yan closed 3 days ago

arnold-yan commented 2 months ago

Describe the bug Recently I profiled neural network layer performance in MLX and compared it with PyTorch. I found that although the MLX forward pass is consistently faster than PyTorch's, on some chips (M1 Pro, M1 Max) PyTorch is much faster (3x~6x) for convolution forward + backward, while on others, such as the M3 Max, MLX is faster than PyTorch.

To Reproduce I have two minimal examples; the networks just have several convolution layers. You may run these two scripts to verify the performance.

time_pytorch_mlx.zip
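The attached scripts are in the zip above; as a rough, framework-agnostic sketch of the timing loop such a benchmark typically uses (the function name, warmup count, and iteration count here are illustrative assumptions, not the values from the attachment):

```python
import time

def average_time_ms(step, warmup=5, iters=20):
    """Time an arbitrary train-step callable, skipping warmup iterations.

    In the real benchmark, `step` would run one forward + backward pass
    of the convolutional network (and, for MLX, evaluate the graph).
    """
    for _ in range(warmup):
        step()
    start = time.perf_counter()
    for _ in range(iters):
        step()
    return (time.perf_counter() - start) * 1000 / iters

if __name__ == "__main__":
    # Trivial stand-in workload just to exercise the harness.
    t = average_time_ms(lambda: sum(i * i for i in range(10_000)))
    print(f"average time: {t:.3f} ms")
```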

awni commented 2 months ago

Same benchmark on an M2 Ultra

average time of Pytorch: 7.20261025428772
average time of MLX: 2.34059739112854
abdussamettrkr commented 2 months ago

On M2 Pro

alwint3r commented 2 months ago

On M3 Max

awni commented 2 months ago

Thanks for the benchmarks everyone! There is clearly an unexpected performance cliff on M1 machines here as MLX is substantially faster on M2+. We'll need to take a deeper look at that to figure out where it's coming from.

pyvadev commented 1 month ago

On M1

jrp2014 commented 1 month ago

M3 Max:

average time of MLX: 2.939736843109131
average time of Pytorch: 5.9829957485198975

awni commented 6 days ago

@arnold-yan. I took a look at this benchmark.

The performance issue turns out to be from the gradient of the second call to nn.Upsample. It uses nearest neighbor interpolation by default. The forward is a gather under the hood and the backward is a scatter add. The scatter add is very inefficient on M1 in this case because it uses atomics and there are a lot of collisions to the same element.

A simple fix is to use linear interpolation (which I believe is what you do with Pytorch anyway).

Changing to:

        upsample = nn.Upsample(scale_factor=(h_scale, w_scale), mode="linear")

The benchmark runs in 2.89 ms on my M1 max compared to PyTorch 13.9 ms.
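As a rough illustration of the gather/scatter-add pairing described above (a NumPy sketch of a 1-D nearest-neighbor upsample, not MLX's actual Metal kernels; the function names are made up):

```python
import numpy as np

# Nearest-neighbor upsample by an integer factor on a 1-D array.
# Forward is a gather: out[i] = x[idx[i]].
# Backward (vjp) is a scatter-add: grad_x[idx[i]] += grad_out[i].
# Every input element is hit by `scale` output positions, so the
# scatter-add has heavy write collisions to the same element -- the
# source of the atomics bottleneck described above.

def upsample_nearest_fwd(x, scale):
    idx = np.arange(x.size * scale) // scale   # each input index repeats `scale` times
    return x[idx], idx                         # gather

def upsample_nearest_bwd(grad_out, idx, in_size):
    grad_x = np.zeros(in_size, dtype=grad_out.dtype)
    np.add.at(grad_x, idx, grad_out)           # unbuffered scatter-add (handles repeated indices)
    return grad_x

x = np.array([1.0, 2.0, 3.0])
y, idx = upsample_nearest_fwd(x, scale=2)
print(y)   # [1. 1. 2. 2. 3. 3.]
g = upsample_nearest_bwd(np.ones_like(y), idx, x.size)
print(g)   # [2. 2. 2.] -- each input element accumulates `scale` contributions
```

Linear interpolation spreads each output gradient over two neighbors with fractional weights, which changes the access pattern of the backward pass.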

arnold-yan commented 5 days ago

Hi @awni, thank you for figuring that out! It was indeed a mistake; I intended to use "linear" here. However, when I changed it to "linear" and ran the test again on my M1 Pro MacBook Pro, I found that the running time actually increased.

average time of MLX: 275.73419642448425 ms

I will find an M1 Max machine to verify this again.

awni commented 5 days ago

@arnold-yan you're right, the benchmark is slower with linear 😓. I made a mistake. Let me keep digging.

awni commented 4 days ago

Hi @arnold-yan https://github.com/ml-explore/mlx/pull/1541 should improve your benchmark a lot. I ran it on an M1 Max and M3 Max and the numbers are now:

| Machine | MLX | PT |
| --- | --- | --- |
| M1 Max | 4.615 | 11.93 |
| M3 Max | 1.938 | 10.77 |
arnold-yan commented 3 days ago

Thanks @awni! Let me test this PR. The running time on my M1 Pro is now:

average time of MLX: 7.930095672607422 ms.

Good job!

arnold-yan commented 3 days ago

By the way, I also tried this change on your PR on my M1 Pro. Now the MLX time changes to

average time of MLX: 18.393556356430054
average time of Pytorch: 15.363117218017578

It is a little slower than PyTorch. Does this mean there is still potential to accelerate MLX performance in linear mode?