Describe the bug
MLX raises a cryptic error (ValueError: [transpose] Recived 3 axes for array with 4 dimensions.) when a function that indexes one argument by the other is vmapped twice, with in_axes [0, None] for the inner vmap and [0, 0] for the outer one.
See the example below, since the steps to reproduce are quite specific.
To Reproduce
import mlx.core as mx

def f(args):
    (x, y) = args
    return x[y]

# Inner vmap: map over axis 0 of x, pass y through unbatched
f = mx.vmap(f, [0, None])
# Outer vmap: map over axis 0 of both x and y
f = mx.vmap(f, [0, 0])
f([mx.ones((2, 3, 4)), mx.zeros(2, dtype=mx.int32)])
# ValueError: [transpose] Recived 3 axes for array with 4 dimensions.
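For comparison, the result the nested vmap should produce can be emulated with explicit loops. This is a minimal NumPy stand-in, assuming that vmap with in_axes [0, None] maps over axis 0 of x while passing y through unbatched, and that per-example results are stacked along a new leading axis. vmap_ref is a hypothetical helper for illustration only, not an MLX API.

```python
import numpy as np

def vmap_ref(fn, in_axes):
    """Hypothetical loop-based stand-in for vmap (not an MLX function).

    in_axes[k] is the axis to map over for args[k], or None to pass
    args[k] through unbatched. Handles a single list argument, matching
    how f is called in the reproduction.
    """
    def mapped(args):
        # Size of the mapped dimension, taken from the first batched argument
        n = next(a.shape[ax] for a, ax in zip(args, in_axes) if ax is not None)
        outs = []
        for i in range(n):
            # Slice example i out of each batched argument; keep the others as-is
            sliced = [a if ax is None else np.take(a, i, axis=ax)
                      for a, ax in zip(args, in_axes)]
            outs.append(fn(sliced))
        # Stack the per-example results along a new leading axis
        return np.stack(outs)
    return mapped

def f(args):
    (x, y) = args
    return x[y]

f = vmap_ref(f, [0, None])  # inner: map over x's axis 0, y unbatched
f = vmap_ref(f, [0, 0])     # outer: map over axis 0 of both x and y
out = f([np.ones((2, 3, 4)), np.zeros(2, dtype=np.int32)])
print(out.shape)  # (2, 3)
```

With the same inputs as the reproduction, the loops produce an array of shape (2, 3) whose element [i, j] is x[i, j, y[i]].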
Perhaps this gives a clue to where the error originates: 3 is the expected number of dimensions returned from f, while 4 is the original number of dimensions. Increasing the number of dimensions of the input arrays changes the numbers in the error message:
Expected behavior
No error, and f([mx.ones((2, 3, 4)), mx.zeros(2, dtype=mx.int32)]).shape should be (2, 3).
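The expected (2, 3) result can also be written without vmap: element [i, j] should equal x[i, j, y[i]]. Below is a minimal NumPy sketch using broadcast integer-array indexing; it only illustrates the expected values, and whether the identical expression works unchanged in MLX has not been checked here.

```python
import numpy as np

x = np.ones((2, 3, 4))
y = np.zeros(2, dtype=np.int32)

# Element [i, j] of the nested-vmap result should be x[i, j, y[i]]
i = np.arange(x.shape[0])[:, None]  # shape (2, 1): outer mapped axis
j = np.arange(x.shape[1])[None, :]  # shape (1, 3): inner mapped axis
out = x[i, j, y[:, None]]           # the three index arrays broadcast to (2, 3)
print(out.shape)  # (2, 3)
```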
Additional context
Discovered while trying to implement a custom interpolation method.