ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

Inconsistency in compile with kwargs #825

Open awni opened 6 months ago

awni commented 6 months ago

The fact that some of the following calls work but others don't seems inconsistent and unexpected. Filing this here mostly so I don't forget about it.

import mlx.core as mx

@mx.compile
def fun(x, y=None):
    if y is not None:
        return x + y
    else:
        return x + 1

fun(mx.array(1.0)) # ok
fun(mx.array(1.0), mx.array(2.0)) # ok
fun(mx.array(1.0), None) # raises an exception
fun(mx.array(1.0), y=None) # raises an exception

romanoneg commented 6 months ago

This is my second time writing this comment, so sorry for any notification spam. It seems related enough, but if this is the wrong place for it, let me know and I'll delete it.

I've run into another edge case with mx.compile and custom container types (here a namedtuple), and I'm not sure why this behavior occurs:

import mlx.core as mx
from collections import namedtuple

exampleClass = namedtuple('Example', ['x','y'])
example_tuple = exampleClass(x=0,y=1)

def foo(mytuple):
    return mytuple[0] + mytuple[1] 

print(foo(mx.array([0,1]))) # works, outputs array(1, dtype=int32)
print(foo(example_tuple)) # works, outputs 1

compiled_foo = mx.compile(foo)

print(compiled_foo(mx.array([0,1]))) # outputs array(1, dtype=int32)
print(compiled_foo(example_tuple)) # outputs None (huh?)

awni commented 6 months ago

In the namedtuple example you are not doing any array operations, so compiling through it doesn't make sense: the 0 and 1 never get cast to mx.array. You can fix it by doing:

import mlx.core as mx
from collections import namedtuple

exampleClass = namedtuple('Example', ['x','y'])
example_tuple = exampleClass(x=mx.array(0),y=mx.array(1))

def foo(mytuple):
    return mytuple[0] + mytuple[1]

compiled_foo = mx.compile(foo)
print(compiled_foo(example_tuple))

We should have better error messaging (or find a way to support it).