ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

[BUG] broadcast of scalar array in last dimension fails after #1035 #1052

Closed. davidkoski closed this issue 2 weeks ago.

davidkoski commented 2 weeks ago

Describe the bug

Broadcasting a scalar when assigning to an index in the last dimension of an array fails after #1035.

To Reproduce


>>> import mlx.core as mx
>>> a = mx.zeros([2, 3, 4, 5, 3])
>>> a[..., 0] = 1
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: [expand_dims] Invalid axes 4 for output array with 1 dimensions.

Expected behavior

The scalar should broadcast to the shape of the indexed slice a[..., 0], which is [2, 3, 4, 5], so that every element whose last index is 0 is set to 1.
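To make the expected shapes concrete, here is a minimal pure-Python sketch of NumPy-style broadcasting, assuming MLX follows the same rules for indexed assignment. The helpers `broadcast_shapes` and `index_last_axis_shape` are hypothetical illustrations written for this report, not MLX API. They show that indexing the last axis of a [2, 3, 4, 5, 3] array yields a [2, 3, 4, 5] slice, and that a scalar (shape `()`) broadcasts to exactly that shape, so the assignment should succeed.

```python
def broadcast_shapes(*shapes):
    """Hypothetical helper: NumPy-style broadcast of several shapes.

    Shapes are aligned from the trailing dimension. Missing leading
    dimensions count as 1, and each pair of sizes must either match
    or have one side equal to 1.
    """
    ndim = max((len(s) for s in shapes), default=0)
    result = []
    for axis in range(-1, -ndim - 1, -1):
        size = 1
        for s in shapes:
            # Treat a dimension that this shape does not have as size 1.
            d = s[axis] if -axis <= len(s) else 1
            if d == 1:
                continue
            if size not in (1, d):
                raise ValueError(f"shapes {shapes} are not broadcastable")
            size = d
        result.append(size)
    return tuple(reversed(result))


def index_last_axis_shape(shape):
    """Hypothetical helper: shape of a[..., k] for an integer k (last axis dropped)."""
    return tuple(shape[:-1])


a_shape = (2, 3, 4, 5, 3)
slice_shape = index_last_axis_shape(a_shape)  # shape of a[..., 0]
scalar_shape = ()                             # the Python scalar 1 has shape ()

print(slice_shape)                                   # (2, 3, 4, 5)
print(broadcast_shapes(slice_shape, scalar_shape))   # (2, 3, 4, 5)
```

Because the broadcast result equals the slice shape, the value fits the target without any expand_dims on a higher axis. The traceback above instead reports axis 4 being applied to a 1-dimensional array.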

Desktop (please complete the following information):