**Closed** — vj-krish closed this 3 months ago
```
X (1000000, 64), Indices [(1000000,)] MLX: 4.529ms PyTorch: 10.208ms
X (1000000, 64), Indices [(1000000,)] MLX: 4.630ms PyTorch: 11.018ms
X (100000, 64), Indices [(100000,)] MLX: 1.987ms PyTorch: 1.091ms
X (1000000,), Indices [(1000000,)] MLX: 0.248ms PyTorch: 0.290ms
X (20000000,), Indices [(20000000,)] MLX: 1.110ms PyTorch: 3.654ms
X (20000000,), Indices [(20000000,)] MLX: 5.488ms PyTorch: 7.152ms
X (1000000, 64), Indices [(1000000,)] MLX: 4.537ms PyTorch: 10.634ms
X (10000000, 64), Indices [(10000000,)] MLX: 43.065ms PyTorch: 100.996ms
X (1000, 10000, 64), Indices [(1000,)] MLX: 58.935ms PyTorch: 108.575ms
X (10000, 100, 100, 21), Indices [(10000,)] MLX: 235.824ms PyTorch: 336.351ms
X (1000, 10), Indices [(1000,), (1000,)] MLX: 0.438ms PyTorch: 0.251ms
```
```
X (1000000, 64), Indices [(1000000,)] MLX: 0.837ms PyTorch: 10.298ms
X (1000000, 64), Indices [(1000000,)] MLX: 1.242ms PyTorch: 11.068ms
X (100000, 64), Indices [(100000,)] MLX: 1.740ms PyTorch: 1.074ms
X (1000000,), Indices [(1000000,)] MLX: 0.231ms PyTorch: 0.284ms
X (20000000,), Indices [(20000000,)] MLX: 0.964ms PyTorch: 3.655ms
X (20000000,), Indices [(20000000,)] MLX: 5.471ms PyTorch: 7.195ms
X (1000000, 64), Indices [(1000000,)] MLX: 0.829ms PyTorch: 10.707ms
X (10000000, 64), Indices [(10000000,)] MLX: 6.666ms PyTorch: 101.062ms
X (1000, 10000, 64), Indices [(1000,)] MLX: 10.013ms PyTorch: 108.335ms
X (10000, 100, 100, 21), Indices [(10000,)] MLX: 22.924ms PyTorch: 335.587ms
X (1000, 10), Indices [(1000,), (1000,)] MLX: 0.422ms PyTorch: 0.263ms
```
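The PR doesn't include the harness that produced these timings, but they presumably come from something along these lines. This is a hedged NumPy sketch (`np.take` standing in for MLX's `mx.take` / PyTorch's `index_select`; the function name `bench_take` is hypothetical), timing a gather along axis 0 with a 1-D index array:

```python
import time
import numpy as np

def bench_take(x_shape, n_idx, repeats=5):
    """Time a gather along axis 0 with 1-D integer indices (best of `repeats`)."""
    rng = np.random.default_rng(0)
    x = rng.standard_normal(x_shape).astype(np.float32)
    idx = rng.integers(0, x_shape[0], size=n_idx)
    best = float("inf")
    for _ in range(repeats):
        t0 = time.perf_counter()
        out = np.take(x, idx, axis=0)
        best = min(best, time.perf_counter() - t0)
    return out, best

out, t = bench_take((1000, 10), 1000)
assert out.shape == (1000, 10)
```

The real benchmark would swap `np.take` for the framework call and synchronize the device (e.g. evaluate the lazy MLX graph) before stopping the clock, since both MLX and PyTorch on GPU dispatch asynchronously.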
Wow, impressive work again @vj-krish 💪 I'll give it a try on real datasets soon!
Looks awesome!! @vj-krish can you check the failed tests?
Pushed a fix, should pass now.
Looks really nice!! Thanks for the exceptional speedup. Could you check the inline comments and let me know what you think?
Thanks for the review! I've responded to all of them. PTAL, thanks!
I think maybe what you want is all the arrays are row_contiguous (e.g. `inputs[i].flags().row_contiguous == True`).
Did you change this?
I think this is broken for col contiguous inputs but maybe I am wrong about that. Might be good to add a test where you send in transposed inputs.
> I think maybe what you want is all the arrays are row_contiguous (e.g. `inputs[i].flags().row_contiguous == True`).
> Did you change this?
Not yet. Was waiting for the thread to converge before addressing it.
I see. Ok, well the indices are 1D so the check you have now is fine, but it's simpler/faster to just check `row_contiguous`.
For the update, if it is multidimensional and column-contiguous I think (?) it is a bug.
> I see. Ok well the indices are 1D so the check you have now is fine but it's simpler/faster to just check `row_contiguous`.
> For the update if it is multidimensional and col contiguous I think (?) it is a bug.
Pushed a fix for this.
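The reviewer's concern can be reproduced in NumPy (a stand-in here; MLX's actual check is on its own array flags): a transposed array is no longer row-contiguous, so a fast path that treats row `i` as the contiguous byte range starting at `i * row_size` reads the wrong elements.

```python
import numpy as np

x = np.arange(6, dtype=np.float32).reshape(2, 3)   # [[0,1,2],[3,4,5]]
xt = x.T                                           # (3, 2) column-contiguous view

assert x.flags["C_CONTIGUOUS"]        # row-contiguous: fast path is safe
assert not xt.flags["C_CONTIGUOUS"]   # transposed: fast path would be wrong

idx = np.array([1])

# Correct, stride-aware gather: row 1 of xt is [1., 4.]
correct = np.take(xt, idx, axis=0)

# Naive fast path over the underlying buffer of xt (which is x's buffer):
flat = x.ravel()
naive = flat[idx[0] * xt.shape[1] : (idx[0] + 1) * xt.shape[1]]  # [2., 3.]

assert not np.array_equal(naive, correct[0])  # the fast path disagrees
```

This is exactly why a transposed-input test case catches the bug: the shapes match, but the values differ.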
Just a couple nits. Do you mind checking them? Then we can merge!
Looks great, thanks!!
Thanks, addressed them.
Thanks @vj-krish, I will merge as soon as the tests clear.
Proposed changes

Add a specialization for 1-D index tensors.

GitHub Issue #506
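As I read the proposal (the actual MLX kernel is C++/Metal, so this is only a hedged Python sketch with a hypothetical `gather_1d` helper): when the indices are 1-D and the source is row-contiguous, the gather reduces to copying whole contiguous rows, skipping general stride arithmetic — which is where the speedup over the generic path comes from.

```python
import numpy as np

def gather_1d(x: np.ndarray, idx: np.ndarray) -> np.ndarray:
    """Sketch of a 1-D-index gather specialization along axis 0."""
    if idx.ndim == 1 and x.flags["C_CONTIGUOUS"]:
        # Fast path: each output row is one contiguous block of row_size elements.
        row_size = int(np.prod(x.shape[1:], dtype=np.int64))
        src = x.reshape(x.shape[0], row_size)        # no copy: already contiguous
        out = np.empty((idx.shape[0], row_size), dtype=x.dtype)
        for j, i in enumerate(idx):                  # each step is one block copy
            out[j] = src[i]
        return out.reshape((idx.shape[0],) + x.shape[1:])
    # General fallback: stride-aware gather (handles transposed inputs correctly).
    return np.take(x, idx, axis=0)

x = np.arange(12, dtype=np.float32).reshape(4, 3)
idx = np.array([3, 0])
assert np.array_equal(gather_1d(x, idx), np.take(x, idx, axis=0))
assert np.array_equal(gather_1d(x.T, np.array([1])), np.take(x.T, np.array([1]), axis=0))
```

Note the guard mirrors the review discussion above: the fast path is only taken when the input is row-contiguous, and everything else falls through to the general gather.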
Checklist

Put an `x` in the boxes that apply.

- [x] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes