ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

improvements to scatter / gather #1541

Closed by awni 4 weeks ago

awni commented 1 month ago

Improves the backward pass of upsampling and mostly closes #1313.

| Machine | MLX Pre | MLX Post | PT |
| --- | --- | --- | --- |
| M1 Max | 39.89 | 4.615 | 11.93 |
| M3 Max | 2.539 | 1.938 | 10.77 |
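
For readers wondering why scatter performance affects the upsampling backward pass, here is a minimal sketch of the relationship. It assumes nearest-neighbor upsampling expressed as an index gather, and it uses NumPy rather than MLX so that it stays self-contained. The function names are hypothetical and are not taken from the MLX source. The point is only that the gradient of a gather is a scatter-add over the same indices.

```python
import numpy as np

def upsample_nearest_1d(x, scale):
    """Forward pass: nearest-neighbor upsampling written as a gather.

    Hypothetical helper for illustration; not MLX's implementation.
    """
    out_len = x.shape[0] * scale
    idx = np.arange(out_len) // scale  # source index for each output element
    return x[idx], idx

def upsample_nearest_1d_backward(grad_out, idx, in_len):
    """Backward pass: scatter-add each output gradient into its source slot."""
    grad_in = np.zeros(in_len, dtype=grad_out.dtype)
    # np.add.at is unbuffered, so repeated indices accumulate correctly.
    np.add.at(grad_in, idx, grad_out)
    return grad_in

x = np.array([1.0, 2.0, 3.0])
y, idx = upsample_nearest_1d(x, scale=2)
grad_in = upsample_nearest_1d_backward(np.ones_like(y), idx, x.shape[0])
```

Because each input element is read `scale` times in the forward pass, it receives `scale` gradient contributions in the backward pass, which is why many writes land on the same location in the scatter.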

M1 Max Benchmarks

| Benchmark | Pre (ms) | Post (ms) |
| --- | --- | --- |
| ND Index Gather | 4.364 | 2.820 |
| 2D Bench, Factor 0.25 | 61.164 | 29.541 |
| 2D Bench, Factor 0.5 | 48.299 | 16.105 |
| 2D Bench, Factor 1 | 45.658 | 13.310 |
| 2D Bench, Factor 2 | 44.240 | 10.184 |
| 2D Bench, Factor 4 | 44.050 | 9.956 |
| 2D Bench, Factor 8 | 44.013 | 10.467 |
| 2D Bench, Factor 16 | 44.074 | 11.148 |
| 2D Bench, Factor 32 | 44.459 | 12.551 |
| 2D Bench, Factor 128 | 115.288 | 47.017 |
| 1D Bench non-contiguous | 34.177 | 2.282 |

M3 Max Benchmarks

| Benchmark | Pre (ms) | Post (ms) |
| --- | --- | --- |
| ND Index Gather | 2.138 | 1.567 |
| 2D Bench, Factor 0.25 | 41.247 | 22.860 |
| 2D Bench, Factor 0.5 | 27.324 | 8.087 |
| 2D Bench, Factor 1 | 23.801 | 4.532 |
| 2D Bench, Factor 2 | 22.912 | 3.309 |
| 2D Bench, Factor 4 | 22.696 | 3.022 |
| 2D Bench, Factor 8 | 22.665 | 2.974 |
| 2D Bench, Factor 16 | 22.661 | 2.984 |
| 2D Bench, Factor 32 | 22.656 | 2.987 |
| 2D Bench, Factor 128 | 22.631 | 2.985 |
| 1D Bench non-contiguous | 17.672 | 1.329 |
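
To read the Pre/Post columns as speedups, divide the Pre time by the Post time. The sketch below does this for three rows copied from the M3 Max table above. The dictionary and function names are illustrative only.

```python
# (pre_ms, post_ms) pairs copied from the M3 Max benchmark table.
m3_max = {
    "ND Index Gather": (2.138, 1.567),
    "2D Bench, Factor 1": (23.801, 4.532),
    "1D Bench non-contiguous": (17.672, 1.329),
}

def speedup(pre_ms, post_ms):
    """Return how many times faster the Post timing is than the Pre timing."""
    return pre_ms / post_ms

speedups = {name: speedup(pre, post) for name, (pre, post) in m3_max.items()}

for name, factor in speedups.items():
    print(f"{name}: {factor:.2f}x")
```

The same calculation applies unchanged to the M1 Max rows.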