ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

[Fix] expand axes for dimension with integer indices in mlx_slice_update #1035

Closed PRESIDENT810 closed 2 weeks ago

PRESIDENT810 commented 3 weeks ago

Proposed changes

Fix #1015

For dimensions indexed with an integer, src should have those dimensions squeezed. For example, given an array a = mx.zeros((5, 4, 3)), a[:, 0] should have shape (5, 3): the second dimension is squeezed, which is what mlx_get_item already does.

However, this is more complex for mlx_set_item, because the array returned by mlx_slice_update must keep the original, unsqueezed shape. To get the same effect more simply, we can add new axes to the update array at the dimensions that would otherwise be squeezed. For example, to perform a[:, 0] = mx.ones((5, 3)), instead of squeezing a to shape (5, 3), we add a new axis to the update array so that its shape becomes (5, 1, 3), which is broadcast-compatible with a's unsqueezed shape (5, 4, 3). In other words, a[:, 0] = mx.ones((5, 3)) now achieves the same effect as a[:, 0:1] = mx.ones((5, 3))[:, None].
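To illustrate the idea above, here is a minimal sketch in NumPy, used only as a stand-in because its slicing and broadcasting rules match the examples in this description. The helper name set_with_expanded_update is hypothetical and is not part of MLX; it handles only tuples of slices and integers, and assumes the update has at most as many dimensions as the squeezed target.

```python
import numpy as np

def set_with_expanded_update(arr, index, update):
    """Hypothetical helper: emulate arr[index] = update without squeezing arr.

    `index` is a tuple containing only slices and Python ints.
    Each int index is turned into a length-1 slice, and the update
    gets a new size-1 axis at that position instead.
    """
    update = np.asarray(update)
    slices, int_axes = [], []
    for ax, idx in enumerate(index):
        if isinstance(idx, int):
            idx = idx % arr.shape[ax]          # normalize negative indices
            slices.append(slice(idx, idx + 1))  # keep the axis, length 1
            int_axes.append(ax)
        else:
            slices.append(idx)

    # Left-pad the update to the rank of the squeezed target so that
    # the axes inserted below line up with arr's axes.
    squeezed_ndim = arr.ndim - len(int_axes)
    pad = max(squeezed_ndim - update.ndim, 0)
    update = update.reshape((1,) * pad + update.shape)

    # Insert a size-1 axis for every integer-indexed dimension.
    for ax in int_axes:
        update = np.expand_dims(update, ax)

    out = arr.copy()
    out[tuple(slices)] = update
    return out

a = np.zeros((5, 4, 3))
upd = np.arange(15, dtype=a.dtype).reshape(5, 3)

# Reference result using ordinary integer indexing.
expected = a.copy()
expected[:, 0] = upd

# Same result via a length-1 slice and an expanded update of shape (5, 1, 3).
result = set_with_expanded_update(a, (slice(None), 0), upd)
```

With these inputs, result and expected should be equal, and the array keeps its original shape (5, 4, 3) throughout, which is the property mlx_slice_update needs.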

I'm a bit unsure whether this is the correct way to implement it, and I don't really understand the old dimension expansion logic:

if (src.ndim() - ax < up.ndim()) {
    upd_expand_dims.push_back(ax - src.ndim());
}

Maybe @jagrit06 can explain it a bit? I guess it tries to do something similar, but I can't see how it's done.

awni commented 2 weeks ago

Closes #1050