ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

[BUG] Indexing inside double vmap raises 'ValueError: [transpose] Recived 3 axes for array with 4 dimensions.' #1517

Closed: magnusdk closed this issue 3 weeks ago.

magnusdk commented 1 month ago

Describe the bug

MLX raises a cryptic error (ValueError: [transpose] Recived 3 axes for array with 4 dimensions.) when a function that indexes one argument with the other is vmapped twice in a specific way.

See the example below; the steps to reproduce are quite specific.

To Reproduce

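```python
import mlx.core as mx

def f(args):
    (x, y) = args
    return x[y]

# Inner vmap: map x over axis 0, pass y through unmapped.
f = mx.vmap(f, [0, None])
# Outer vmap: map both x and y over axis 0.
f = mx.vmap(f, [0, 0])
f([mx.ones((2, 3, 4)), mx.zeros(2, dtype=mx.int32)])
# ValueError: [transpose] Recived 3 axes for array with 4 dimensions.
```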

Perhaps 3 is the number of dimensions f is expected to return and 4 is the original number of dimensions, which might hint at where the error originates.

Increasing the number of dimensions in the arrays changes the numbers in the error message:

```python
import mlx.core as mx

def f2(args):
    (x, y) = args
    return x[y]

# Same structure as above: inner vmap maps only x, outer vmap maps x and y.
f2 = mx.vmap(f2, [0, None])
f2 = mx.vmap(f2, [0, 0])
f2([mx.ones((2, 3, 4, 5, 6)), mx.zeros((2, 3, 4), dtype=mx.int32)])
# ValueError: [transpose] Recived 7 axes for array with 8 dimensions.
```
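To make the expected shapes concrete, here is a small pure-Python sketch that traces shapes through the two vmap levels. expected_shape is a hypothetical helper, not an MLX API, and it assumes x[y] follows NumPy-style integer-array indexing on axis 0, so the result has shape y.shape + x.shape[1:].

```python
# Hypothetical helper (not part of MLX): traces the output shape the nested
# vmap above should produce, assuming NumPy-style integer-array indexing on
# axis 0, i.e. x[y] has shape y.shape + x.shape[1:].

def expected_shape(x_shape, y_shape):
    # Outer vmap, in_axes [0, 0]: strip axis 0 from both x and y.
    outer = x_shape[0]
    assert y_shape[0] == outer, "outer vmap needs matching sizes on axis 0"
    x_s, y_s = x_shape[1:], y_shape[1:]
    # Inner vmap, in_axes [0, None]: strip axis 0 from x only.
    inner = x_s[0]
    x_s = x_s[1:]
    # Innermost body: x[y] replaces x's first axis with y's shape.
    body = tuple(y_s) + tuple(x_s[1:])
    # out_axes defaults to 0, so each vmap level prepends its mapped size.
    return (outer, inner) + body

print(expected_shape((2, 3, 4), (2,)))              # (2, 3)
print(expected_shape((2, 3, 4, 5, 6), (2, 3, 4)))   # (2, 3, 3, 4, 5, 6)
```

Under these assumptions, the first reproduction should return shape (2, 3) and the second shape (2, 3, 3, 4, 5, 6).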

Expected behavior

No error, and f([mx.ones((2, 3, 4)), mx.zeros(2, dtype=mx.int32)]).shape should be (2, 3).
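As a reference for what the doubly vmapped f should compute, here is a sketch of the same computation written as plain Python loops instead of mx.vmap. double_vmap_reference is a hypothetical name, and x and y are nested Python lists standing in for the MLX arrays from the first reproduction.

```python
# Hypothetical loop-based reference for the first reproduction (no MLX needed).
# x stands in for mx.ones((2, 3, 4)) and y for mx.zeros(2, dtype=mx.int32).

def double_vmap_reference(x, y):
    out = []
    for i in range(len(x)):          # outer vmap: in_axes [0, 0]
        row = []
        for j in range(len(x[i])):   # inner vmap: in_axes [0, None]
            # innermost f: x[i][j] indexed by y[i], which is shared across j
            row.append(x[i][j][y[i]])
        out.append(row)
    return out

x = [[[1.0] * 4 for _ in range(3)] for _ in range(2)]
y = [0, 0]
out = double_vmap_reference(x, y)
print(len(out), len(out[0]))  # prints: 2 3
```

The nested lists come out as 2 rows of 3 values, matching the expected shape (2, 3).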


Additional context

Discovered while trying to implement a custom interpolation method.

magnusdk commented 3 weeks ago

Thanks for fixing this! :) I can confirm that it works for my use case now.