Proposed changes
Fix #1015
For dimensions with integer indices, `src` should squeeze these dimensions. For example, an array `a = mx.zeros((5, 4, 3))` indexed by `a[:, 0]` should end up with shape (5, 3), with the second dimension squeezed, as is done in the implementation of `mlx_get_item`.
However, this is more complex for `mlx_set_item`, since the array returned by `mlx_slice_update` must keep the original, unsqueezed shape. To get the same effect more simply, we can instead expand new axes in the update array at the dimensions that are supposed to be squeezed. For example, to update `a[:, 0] = mx.ones((5, 3))`, instead of squeezing the indexed view of `a` to shape (5, 3), we expand a new axis in the update array so that its shape becomes (5, 1, 3), which is broadcast compatible with the unsqueezed shape of `a`, (5, 4, 3). In other words, `a[:, 0] = mx.ones((5, 3))` now achieves the same effect as `a[:, 0:1] = mx.ones((5, 3))[:, None]`.
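The squeeze/expand bookkeeping described above can be sketched as plain shape arithmetic. This is a minimal, framework-free illustration, not the actual `mlx_get_item`/`mlx_set_item` code: the helpers `indexed_shape`, `broadcast_shapes`, and `expand_update_shape` are hypothetical names, indices are assumed to be a tuple of Python ints and slices only (no arrays, `None`, or `Ellipsis`), and NumPy-style broadcasting is assumed.

```python
# Hypothetical helpers (not MLX APIs) illustrating the squeeze/expand idea
# with pure shape arithmetic. Indices are assumed to be a tuple of Python
# ints and slices only; missing trailing indices behave like ":".


def indexed_shape(shape, index, squeeze=True):
    """Shape of array[index] for basic int/slice indexing.

    squeeze=True drops integer-indexed dims (the get_item behaviour);
    squeeze=False keeps them with size 1, i.e. the unsqueezed region that
    a slice update writes into.
    """
    out = []
    for dim, size in enumerate(shape):
        idx = index[dim] if dim < len(index) else slice(None)
        if isinstance(idx, int):
            if not squeeze:
                out.append(1)
        else:
            start, stop, step = idx.indices(size)
            out.append(len(range(start, stop, step)))
    return tuple(out)


def broadcast_shapes(a, b):
    """NumPy-style broadcast of two shapes; raises ValueError if incompatible."""
    ndim = max(len(a), len(b))
    a = (1,) * (ndim - len(a)) + tuple(a)
    b = (1,) * (ndim - len(b)) + tuple(b)
    out = []
    for x, y in zip(a, b):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"shapes {a} and {b} are not broadcast compatible")
        out.append(y if x == 1 else x)
    return tuple(out)


def expand_update_shape(update_shape, shape, index):
    """Shape of the update after inserting a size-1 axis at every int index.

    The update is first broadcast to the squeezed target shape, then a new
    axis is inserted wherever the index is an integer.
    """
    squeezed = indexed_shape(shape, index, squeeze=True)
    full = broadcast_shapes(update_shape, squeezed)
    if full != squeezed:
        raise ValueError(f"update of shape {tuple(update_shape)} does not fit {squeezed}")
    dims = iter(full)
    out = []
    for dim in range(len(shape)):
        idx = index[dim] if dim < len(index) else slice(None)
        out.append(1 if isinstance(idx, int) else next(dims))
    return tuple(out)


shape = (5, 4, 3)          # a = mx.zeros((5, 4, 3))
index = (slice(None), 0)   # a[:, 0]
print(indexed_shape(shape, index))                 # (5, 3)
print(indexed_shape(shape, index, squeeze=False))  # (5, 1, 3)
print(expand_update_shape((5, 3), shape, index))   # (5, 1, 3)
```

For `a[:, 0]` on a (5, 4, 3) array, an update of shape (5, 3) or (3,) expands to (5, 1, 3), which is exactly the shape of the unsqueezed region `a[:, 0:1]`, so the slice update can proceed without squeezing the target.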
I'm a bit unsure whether this is the correct way to implement it, and I don't really understand the old dimension-expansion logic. Maybe @jagrit06 can explain it a bit? I guess it tries to do something similar, but I can't see how it's done.
Checklist
Put an `x` in the boxes that apply.
- I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes