ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

scatter operations missing index locations #1631

Open sarmientoF opened 14 hours ago

sarmientoF commented 14 hours ago

It would be great if the scatter operations also returned the index locations, as torch_scatter does:

```python
import torch
from torch_scatter import scatter_max

src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = src.new_zeros((2, 6))

out, argmax = scatter_max(src, index, out=out)
```

https://pytorch-scatter.readthedocs.io/en/1.3.0/functions/max.html

Is this an upcoming feature?

awni commented 6 hours ago

There's no plan to add this at the moment. A scatter_argmax might be doable, but I'm curious: what would you use it for?
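
For readers unfamiliar with the request, here is a rough pure-Python sketch of the semantics being asked for, using the inputs from the original comment. It is not MLX or torch_scatter code: the function name `scatter_max_with_argmax` is hypothetical, and the choices to mark untouched slots with `-1` and to let a later source element win a tie are assumptions of this sketch rather than confirmed library behavior.

```python
def scatter_max_with_argmax(src, index, out):
    """Row-wise scatter-max that also records which source position won.

    Hypothetical reference helper, not a library API.
    src, index: lists of rows with equal lengths.
    out: initial values for each row; left unmodified (a copy is returned).
    """
    values = [list(row) for row in out]            # copy so `out` is untouched
    argmax = [[-1] * len(row) for row in out]      # -1 means "nothing scattered here" (assumption)
    for r, (src_row, index_row) in enumerate(zip(src, index)):
        for j, (value, target) in enumerate(zip(src_row, index_row)):
            # Assumption: ties go to the later source position (>=).
            if value >= values[r][target]:
                values[r][target] = value
                argmax[r][target] = j
    return values, argmax


src = [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]
index = [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]
out = [[0] * 6 for _ in range(2)]

values, positions = scatter_max_with_argmax(src, index, out)
print(values)     # [[0, 0, 4, 3, 2, 0], [2, 4, 3, 0, 0, 0]]
print(positions)  # [[-1, -1, 3, 4, 0, 1], [1, 4, 3, -1, -1, -1]]
```

The second result is the part that is missing from a plain scatter-max: for each output slot it tells you which column of `src` supplied the maximum.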