ml-explore/mlx
MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License
improvements to scatter / gather #1541
Closed (awni closed this 4 weeks ago)

awni commented 1 month ago:
- Add a contiguous flag to the indices for scatter / gather
- For scatter: do more work per thread when the ratio of indices to output is high
- Add a vector copy + possible donation for scatter
- Specialization for contiguous updates
- Remove the 1D specializations for scatter, since they are subsumed by the improvements to the core kernel
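The list above describes changes to MLX's Metal kernels. The following is only a sequential CPU reference sketch in Python, not MLX's implementation. `is_row_contiguous`, `scatter_add`, `work_per_thread` and `donate` are hypothetical names chosen to illustrate three of the ideas: what a contiguity check on an index array tests, how giving each thread several consecutive updates reduces the number of threads launched when there are many indices per output element, and what donating the input buffer saves compared with copying it.

```python
from typing import List, Sequence


def is_row_contiguous(shape: Sequence[int], strides: Sequence[int]) -> bool:
    """Return True if element strides describe a dense row-major (C-order) layout."""
    expected = 1
    # Walk from the innermost axis outwards; size-1 axes may have any stride.
    for dim, stride in zip(reversed(shape), reversed(strides)):
        if dim != 1 and stride != expected:
            return False
        expected *= dim
    return True


def scatter_add(
    src: List[float],
    indices: Sequence[int],
    updates: Sequence[float],
    work_per_thread: int = 1,
    donate: bool = False,
) -> List[float]:
    """Sequential CPU reference for a flat scatter-add (hypothetical, not a GPU kernel)."""
    if work_per_thread < 1:
        raise ValueError("work_per_thread must be >= 1")
    # Scatter starts from a copy of the source; donating reuses (mutates) the
    # source buffer instead of allocating a new one and copying into it.
    out = src if donate else list(src)
    n = len(indices)
    num_threads = (n + work_per_thread - 1) // work_per_thread  # ceil division
    for tid in range(num_threads):
        # Each simulated thread owns a run of consecutive updates. On a GPU,
        # threads hitting the same output slot would need atomic adds; this
        # sequential loop does not.
        start = tid * work_per_thread
        for i in range(start, min(start + work_per_thread, n)):
            out[indices[i]] += updates[i]
    return out


if __name__ == "__main__":
    print(is_row_contiguous((2, 3), (3, 1)))  # True: row-major
    print(is_row_contiguous((3, 2), (1, 3)))  # False: transposed view
    print(scatter_add([0.0] * 3, [0, 2, 0], [1.0, 2.0, 3.0], work_per_thread=2))
```

With `work_per_thread=4`, 1,000 updates are covered by 250 simulated threads instead of 1,000; the cost is a longer serial loop inside each thread.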
Improves the backward pass of upsampling and mostly closes #1313.
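To see why upsampling's backward pass depends on scatter, consider nearest-neighbour upsampling, used here as an assumed example (this sketch does not reproduce MLX's upsampling code). The forward pass is a gather, and its gradient is a scatter-add in which `scale` output gradients land on each input element, so the number of updates per output element grows with the upsampling factor. The function names below are hypothetical.

```python
from typing import List


def upsample_nearest_1d(x: List[float], scale: int) -> List[float]:
    # Forward pass is a gather: output position j reads input element j // scale.
    idx = [j // scale for j in range(len(x) * scale)]
    return [x[i] for i in idx]


def upsample_nearest_1d_backward(grad_out: List[float], in_len: int, scale: int) -> List[float]:
    # Backward pass is a scatter-add: each output gradient is accumulated into
    # the input element it was gathered from (scale updates per input element).
    grad_in = [0.0] * in_len
    for j, g in enumerate(grad_out):
        grad_in[j // scale] += g
    return grad_in


if __name__ == "__main__":
    print(upsample_nearest_1d([1.0, 2.0], 2))
    print(upsample_nearest_1d_backward([1.0, 2.0, 3.0, 4.0], 2, 2))
```

In 2D the same pattern gives `scale * scale` updates per input element, which is the high indices-to-output regime that the "more work per thread" change targets.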
| Machine | MLX Pre | MLX Post | PT (PyTorch) |
|---|---|---|---|
| M1 Max | 39.89 | 4.615 | 11.93 |
| M3 Max | 2.539 | 1.938 | 10.77 |
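The speedups implied by the upsampling table can be computed directly from its values. The PR does not state the units, but the ratios are unit-free. This is only arithmetic on the reported numbers; `timings` and `speedups` are names introduced for this sketch, and the "pt" key refers to the PT column (presumably PyTorch).

```python
from typing import Dict

# Timings copied from the table above (units not stated in the PR).
timings: Dict[str, Dict[str, float]] = {
    "M1 Max": {"pre": 39.89, "post": 4.615, "pt": 11.93},
    "M3 Max": {"pre": 2.539, "post": 1.938, "pt": 10.77},
}


def speedups(t: Dict[str, float]) -> Dict[str, float]:
    # A ratio > 1 means the post-PR MLX timing is lower (faster).
    return {"vs_pre": t["pre"] / t["post"], "vs_pt": t["pt"] / t["post"]}


if __name__ == "__main__":
    for machine, t in timings.items():
        s = speedups(t)
        print(f"{machine}: {s['vs_pre']:.2f}x vs MLX Pre, {s['vs_pt']:.2f}x vs PT")
```

On these numbers the change matters most on the M1 Max (roughly 8.6x over the previous MLX timing), while on the M3 Max the larger gap is against the PT column.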
M1 Max Benchmarks
| Benchmark | Pre (ms) | Post (ms) |
|---|---|---|
| ND Index Gather | 4.364 | 2.820 |
| 2D Bench, Factor 0.25 | 61.164 | 29.541 |
| 2D Bench, Factor 0.5 | 48.299 | 16.105 |
| 2D Bench, Factor 1 | 45.658 | 13.310 |
| 2D Bench, Factor 2 | 44.240 | 10.184 |
| 2D Bench, Factor 4 | 44.050 | 9.956 |
| 2D Bench, Factor 8 | 44.013 | 10.467 |
| 2D Bench, Factor 16 | 44.074 | 11.148 |
| 2D Bench, Factor 32 | 44.459 | 12.551 |
| 2D Bench, Factor 128 | 115.288 | 47.017 |
| 1D Bench non-contiguous | 34.177 | 2.282 |
M3 Max Benchmarks
| Benchmark | Pre (ms) | Post (ms) |
|---|---|---|
| ND Index Gather | 2.138 | 1.567 |
| 2D Bench, Factor 0.25 | 41.247 | 22.860 |
| 2D Bench, Factor 0.5 | 27.324 | 8.087 |
| 2D Bench, Factor 1 | 23.801 | 4.532 |
| 2D Bench, Factor 2 | 22.912 | 3.309 |
| 2D Bench, Factor 4 | 22.696 | 3.022 |
| 2D Bench, Factor 8 | 22.665 | 2.974 |
| 2D Bench, Factor 16 | 22.661 | 2.984 |
| 2D Bench, Factor 32 | 22.656 | 2.987 |
| 2D Bench, Factor 128 | 22.631 | 2.985 |
| 1D Bench non-contiguous | 17.672 | 1.329 |