ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

Up to 10x faster scatter. #709

Closed · vj-krish closed this 3 months ago

vj-krish commented 3 months ago

Proposed changes

Add a specialization for 1-D index tensors.

GitHub Issue #506
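
For context, here is a minimal sketch of the scatter pattern this specialization targets: a single 1-D integer index array selecting rows of a larger array to update. The shapes are illustrative only and not taken from the benchmarks below.

```python
import mlx.core as mx

# Destination array and a 1-D array of row indices.
x = mx.zeros((1_000_000, 64))
idx = mx.random.randint(0, x.shape[0], (100_000,))
updates = mx.ones((100_000, 64))

# Indexed assignment performs a scatter along axis 0 driven by a single
# 1-D index array, which is the case this PR adds a fast path for.
x[idx] = updates
mx.eval(x)
```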


vj-krish commented 3 months ago

Performance on an M2 Max (38-core GPU) MacBook Pro.

Before

| X shape | Indices shapes | MLX (ms) | PyTorch (ms) |
| --- | --- | --- | --- |
| (1000000, 64) | [(1000000,)] | 4.529 | 10.208 |
| (1000000, 64) | [(1000000,)] | 4.630 | 11.018 |
| (100000, 64) | [(100000,)] | 1.987 | 1.091 |
| (1000000,) | [(1000000,)] | 0.248 | 0.290 |
| (20000000,) | [(20000000,)] | 1.110 | 3.654 |
| (20000000,) | [(20000000,)] | 5.488 | 7.152 |
| (1000000, 64) | [(1000000,)] | 4.537 | 10.634 |
| (10000000, 64) | [(10000000,)] | 43.065 | 100.996 |
| (1000, 10000, 64) | [(1000,)] | 58.935 | 108.575 |
| (10000, 100, 100, 21) | [(10000,)] | 235.824 | 336.351 |
| (1000, 10) | [(1000,), (1000,)] | 0.438 | 0.251 |

After

| X shape | Indices shapes | MLX (ms) | PyTorch (ms) |
| --- | --- | --- | --- |
| (1000000, 64) | [(1000000,)] | 0.837 | 10.298 |
| (1000000, 64) | [(1000000,)] | 1.242 | 11.068 |
| (100000, 64) | [(100000,)] | 1.740 | 1.074 |
| (1000000,) | [(1000000,)] | 0.231 | 0.284 |
| (20000000,) | [(20000000,)] | 0.964 | 3.655 |
| (20000000,) | [(20000000,)] | 5.471 | 7.195 |
| (1000000, 64) | [(1000000,)] | 0.829 | 10.707 |
| (10000000, 64) | [(10000000,)] | 6.666 | 101.062 |
| (1000, 10000, 64) | [(1000,)] | 10.013 | 108.335 |
| (10000, 100, 100, 21) | [(10000,)] | 22.924 | 335.587 |
| (1000, 10) | [(1000,), (1000,)] | 0.422 | 0.263 |
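
For anyone who wants to reproduce numbers of this kind, below is a rough sketch of how one such measurement could be taken. It is not the script that produced the tables above; the shapes, repetition count, and synchronization points are assumptions, and the MLX and PyTorch loops are only approximately comparable.

```python
import time
import mlx.core as mx
import torch

N, D, REPS = 1_000_000, 64, 20

def time_mlx():
    x = mx.zeros((N, D))
    idx = mx.random.randint(0, N, (N,))
    upd = mx.ones((N, D))
    mx.eval(x, idx, upd)                  # materialize inputs before timing
    start = time.perf_counter()
    for _ in range(REPS):
        x[idx] = upd                      # scatter with a 1-D index array
        mx.eval(x)                        # force the lazy computation to run
    return (time.perf_counter() - start) / REPS * 1e3

def time_torch():
    dev = "mps" if torch.backends.mps.is_available() else "cpu"
    x = torch.zeros(N, D, device=dev)
    idx = torch.randint(0, N, (N,), device=dev)
    upd = torch.ones(N, D, device=dev)
    start = time.perf_counter()
    for _ in range(REPS):
        x[idx] = upd                      # indexed assignment (scatter) on the MPS device
    if dev == "mps":
        torch.mps.synchronize()           # wait for queued GPU work before stopping the clock
    return (time.perf_counter() - start) / REPS * 1e3

print(f"X ({N}, {D}), Indices [({N},)] "
      f"MLX: {time_mlx():.3f}ms PyTorch: {time_torch():.3f}ms")
```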

TristanBilot commented 3 months ago

Wow, impressive work again @vj-krish 💪 I'll give it a try on real datasets soon!

awni commented 3 months ago

Looks awesome!! @vj-krish can you check the failed tests?

vj-krish commented 3 months ago

> Looks awesome!! @vj-krish can you check the failed tests?

Pushed a fix, should pass now.

vj-krish commented 3 months ago

> Looks really nice!! Thanks for the exceptional speedup. Could you check the inline comments and let me know what you think?

Thanks for the review! I've responded to all of them. PTAL, thanks!

awni commented 3 months ago

I think maybe what you want is that all the arrays are row_contiguous (e.g. inputs[i].flags().row_contiguous == True).

Did you change this?

awni commented 3 months ago

I think this is broken for col-contiguous inputs, but maybe I am wrong about that. Might be good to add a test where you send in transposed inputs.
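
As an illustration of the kind of test being suggested here (a sketch only, not the test that was actually added; the test name and shapes are made up), one could pass a transposed, and therefore non-row-contiguous, updates array and compare against a NumPy reference:

```python
import mlx.core as mx
import numpy as np

def test_scatter_with_transposed_updates():
    x = mx.zeros((8, 4))
    idx = mx.array([1, 3, 5])
    # A transposed view is col-contiguous rather than row-contiguous,
    # which is the layout under discussion here.
    upd = mx.transpose(mx.arange(12, dtype=mx.float32).reshape(4, 3))

    x[idx] = upd

    ref = np.zeros((8, 4), dtype=np.float32)
    ref[[1, 3, 5]] = np.arange(12, dtype=np.float32).reshape(4, 3).T
    assert np.array_equal(np.array(x), ref)
```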

vj-krish commented 3 months ago

> I think maybe what you want is that all the arrays are row_contiguous (e.g. inputs[i].flags().row_contiguous == True).
>
> Did you change this?

Not yet. Was waiting for the thread to converge before addressing it.

awni commented 3 months ago

I see. OK, well the indices are 1D, so the check you have now is fine, but it's simpler/faster to just check row_contiguous.

For the update, if it is multidimensional and col-contiguous, I think (?) it is a bug.
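
For readers following along: row-contiguous means the buffer is laid out in C (row-major) order, while a transposed 2-D view is col-contiguous. MLX records these properties on the C++ array (the flags().row_contiguous field quoted above); the NumPy snippet below is only meant to illustrate the distinction, not MLX's internals.

```python
import numpy as np

a = np.zeros((4, 5))
print(a.flags["C_CONTIGUOUS"], a.flags["F_CONTIGUOUS"])  # True False  (row-contiguous)

t = a.T  # transposed view over the same buffer
print(t.flags["C_CONTIGUOUS"], t.flags["F_CONTIGUOUS"])  # False True  (col-contiguous)
```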

vj-krish commented 3 months ago

> I see. OK, well the indices are 1D, so the check you have now is fine, but it's simpler/faster to just check row_contiguous.
>
> For the update, if it is multidimensional and col-contiguous, I think (?) it is a bug.

Pushed a fix for this.

vj-krish commented 3 months ago

> Just a couple nits. Do you mind checking them? Then we can merge!
>
> Looks great, thanks!!

Thanks, addressed them.

awni commented 3 months ago

Thanks @vj-krish, I will merge as soon as the tests clear.