ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

[BUG] give better diagnostic message when calling compiled code with an eval in it -- currently "Attempting to eval an array without a primitive" #1083

Closed: davidkoski closed this issue 1 week ago

davidkoski commented 1 week ago

Describe the bug

From https://github.com/ml-explore/mlx-swift/issues/82 -- compiled code produces

ValueError: [eval] Attempting to eval an array without a primitive

To Reproduce

```python
import mlx.core as mx

def foo(x):
    y = mx.softmax(x, -1)
    logits = y[0, 0, 0].item()
    return mx.full([1, 100], mx.array(logits))

compiled = mx.compile(foo)

# without compile it works
# compiled = foo

for _ in range(10):
    x = mx.random.randint(0, 10, [1, 32, 50000]).astype(mx.float32)
    y = compiled(x)
    value = y[0, -1].item()
    print(value)
```

produces:

ValueError: [eval] Attempting to eval an array without a primitive

Run without compile, it prints a list of random numbers as expected.

Expected behavior

The compiled version should behave the same way as the non-compiled code.


awni commented 1 week ago

This is expected behavior. You can't compile through an eval.

I think we need to update the error message so it's actually useful.
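
A toy model may help show why this has to fail. The sketch below is not MLX's implementation: it is a minimal, hypothetical tracer (Traced, toy_compile and CompileError are made-up names) that assumes compilation works by calling the function on a symbolic input and recording the operations. Arithmetic on a symbolic value just builds a graph, but item() needs an actual number, which does not exist yet while tracing, so the only safe thing it can do is raise. The message it raises is one illustration of a more actionable diagnostic.

```python
class CompileError(RuntimeError):
    """Hypothetical error type raised by the toy tracer (not part of MLX)."""


class Traced:
    """Toy symbolic placeholder: records how a value is computed, holds no data."""

    def __init__(self, op, inputs=()):
        self.op = op
        self.inputs = inputs

    def __add__(self, other):
        return Traced("add", (self, other))

    def __mul__(self, other):
        return Traced("mul", (self, other))

    # Addition and multiplication commute, so reuse them for 1 + x and 2 * x.
    __radd__ = __add__
    __rmul__ = __mul__

    def item(self):
        # A concrete number is requested, but none exists while tracing.
        raise CompileError(
            "[compile] item() was called on a traced array inside a compiled "
            "function. Evaluation is not allowed during compilation: keep the "
            "value as an array, or move the item() call outside the function."
        )


def toy_compile(fn):
    """Trace fn on its first call with a symbolic input, then replay the graph."""
    graph = None

    def evaluate(node, x):
        if not isinstance(node, Traced):
            return node  # a Python constant captured during tracing
        if node.op == "input":
            return x
        a, b = (evaluate(n, x) for n in node.inputs)
        return a + b if node.op == "add" else a * b

    def run(x):
        nonlocal graph
        if graph is None:
            graph = fn(Traced("input"))  # tracing: no real data flows here
        return evaluate(graph, x)

    return run


def good(x):
    return x * 2 + 1  # only builds a graph while tracing


def bad(x):
    s = (x * 2).item()  # needs a concrete value in the middle of tracing
    return x + s


print(toy_compile(good)(3))  # 7

try:
    toy_compile(bad)(3)
except CompileError as e:
    print(e)
```

In this toy model, failing at trace time is the only safe choice: if item() returned some number, it would be baked into the recorded graph as a constant and silently reused on later calls with different inputs.
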

davidkoski commented 1 week ago

How would you change the code to make it work? Eventually you have to eval the result if you want it.

davidkoski commented 1 week ago

ooooh, I understand:

    logits = y[0, 0, 0].item()

That call inside the compiled function is an eval -- you can't run an eval inside compiled code.

awni commented 1 week ago

This case is easy because you can (and should) just pass y[0, 0, 0] to mx.full directly. Even without compile it will be faster, since you don't have to come back to user land or make extra arrays.

```python
import mlx.core as mx

def foo(x):
    y = mx.softmax(x, -1)
    return mx.full([1, 100], y[0, 0, 0])

compiled = mx.compile(foo)

# without compile it works
# compiled = foo

for _ in range(10):
    x = mx.random.randint(0, 10, [1, 32, 50000]).astype(mx.float32)
    y = compiled(x)
    value = y[0, -1].item()
    print(value)
```

davidkoski commented 1 week ago

So the fix is probably to not call item(): it is only used to create a new array on the next line.

davidkoski commented 1 week ago

```python
import mlx.core as mx

def foo(x):
    y = mx.softmax(x, -1)
    logits = y[0, 0, 0]  # stays an array: no eval inside the compiled function
    return mx.full([1, 100], logits)
```

davidkoski commented 1 week ago

Thanks @awni ! Do you want to repurpose this issue for the better error message?

awni commented 1 week ago

Yes, we can leave it open for that. It's long past time to fix that error message.