Closed: arnold-yan closed this issue 3 days ago
Same benchmark on an M2 Ultra:
average time of PyTorch: 7.20261025428772
average time of MLX: 2.34059739112854
On M2 Pro
On M3 Max
Thanks for the benchmarks everyone! There is clearly an unexpected performance cliff on M1 machines here as MLX is substantially faster on M2+. We'll need to take a deeper look at that to figure out where it's coming from.
On M1
On M3 Max:
average time of MLX: 2.939736843109131
average time of PyTorch: 5.9829957485198975
@arnold-yan I took a look at this benchmark.
The performance issue turns out to be in the gradient of the second call to `nn.Upsample`, which uses nearest-neighbor interpolation by default. The forward pass is a gather under the hood and the backward pass is a scatter-add. The scatter-add is very inefficient on M1 in this case because it uses atomics and there are many collisions on the same element.
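To illustrate why the collisions happen, here is a hedged NumPy sketch of the nearest-neighbor backward pass (the function name and implementation are mine, not MLX's actual kernel): every output gradient is scatter-added back to the input element it was copied from, so with a scale factor of `s`, `s**2` output positions accumulate into each input element. On a GPU those accumulations become contended atomic adds.

```python
import numpy as np

def upsample_nearest_backward(grad_out, scale):
    # Backward of 2-D nearest-neighbor upsampling (illustrative sketch):
    # each output gradient is scatter-added to its source input element.
    # scale**2 outputs map to the same input index, which is why a GPU
    # implementation using atomic adds sees heavy contention.
    h, w = grad_out.shape
    grad_in = np.zeros((h // scale, w // scale), dtype=grad_out.dtype)
    rows = np.arange(h) // scale          # source row of each output row
    cols = np.arange(w) // scale          # source col of each output col
    np.add.at(grad_in, (rows[:, None], cols[None, :]), grad_out)
    return grad_in

g = upsample_nearest_backward(np.ones((4, 4)), 2)
# every input element receives scale**2 = 4 contributions
```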
A simple fix is to use linear interpolation (which I believe is what you do with PyTorch anyway).
Changing to:
upsample = nn.Upsample(scale_factor=(h_scale, w_scale), mode="linear")
The benchmark runs in 2.89 ms on my M1 Max, compared to 13.9 ms for PyTorch.
Hi @awni, thank you for figuring that out! It was indeed a mistake; I intended to use "linear" here. However, when I changed it to "linear" and ran the test again on my M1 Pro MacBook Pro, the running time actually increased:
average time of MLX: 275.73419642448425 ms
I will find an M1 Max machine to verify this again.
@arnold-yan you're right, the benchmark is slower with linear 😓. I made a mistake. Let me keep digging.
Hi @arnold-yan https://github.com/ml-explore/mlx/pull/1541 should improve your benchmark a lot. I ran it on an M1 Max and M3 Max and the numbers are now:
| Machine | MLX (ms) | PT (ms) |
|---|---|---|
| M1 Max | 4.615 | 11.93 |
| M3 Max | 1.938 | 10.77 |
Thanks @awni! Let me test with this PR. The running time on M1 Pro is now:
average time of MLX: 7.930095672607422 ms.
Good job!
By the way, I also tried this change with your PR on M1 Pro. The MLX time is now:
average time of MLX: 18.393556356430054
average time of PyTorch: 15.363117218017578
It is a little slower than PyTorch. Does this mean there is still room to accelerate MLX in linear mode?
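For intuition on why linear mode behaves differently, here is a hedged 1-D NumPy sketch (my own simplified implementation with boundary handling glossed over, not MLX's or PyTorch's kernel): the linear backward is still a scatter-add, but each output gradient is split between its two neighboring input elements with fractional weights, so the arithmetic and the collision pattern differ from the nearest-neighbor case.

```python
import numpy as np

def upsample_linear_backward_1d(grad_out, scale):
    # Backward of 1-D linear upsampling (illustrative sketch, simplified
    # boundary handling): each output gradient is split between the two
    # nearest input elements using the interpolation weights, so it is
    # still a scatter-add, just with fractional weights.
    n_out = grad_out.shape[0]
    n_in = n_out // scale
    src = (np.arange(n_out) + 0.5) / scale - 0.5   # source coordinates
    left = np.clip(np.floor(src).astype(int), 0, n_in - 1)
    right = np.clip(left + 1, 0, n_in - 1)
    w_right = np.clip(src - np.floor(src), 0.0, 1.0)
    grad_in = np.zeros(n_in, dtype=grad_out.dtype)
    np.add.at(grad_in, left, grad_out * (1.0 - w_right))
    np.add.at(grad_in, right, grad_out * w_right)
    return grad_in

g = upsample_linear_backward_1d(np.ones(8), 2)
# total gradient mass is conserved: the weights for each output sum to 1
```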
Describe the bug
Recently I profiled neural network layer performance in MLX and compared it with PyTorch. I found that although MLX forwarding is consistently faster than PyTorch, on some chips (M1 Pro, M1 Max) PyTorch is much faster (3x~6x) for convolution forward + backward, while on other chips such as M3 Max, MLX is faster than PyTorch.
To Reproduce
To reproduce this, I have two minimal examples. The networks just contain several convolution layers. You can run these two scripts to verify the performance.
time_pytorch_mlx.zip
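For reference, timing harnesses for this kind of comparison usually look something like the sketch below (the helper name, warmup, and iteration counts are my own assumptions, not the attached scripts):

```python
import time

def avg_time_ms(fn, warmup=3, iters=20):
    # Average wall-clock time of fn in milliseconds, after a few warmup
    # runs. For MLX, fn should call mx.eval(...) on its outputs so the
    # lazy computation graph is actually executed inside the timed region.
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters * 1e3

# example: time a trivial Python workload
t = avg_time_ms(lambda: sum(i * i for i in range(10_000)))
```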