Open awni opened 6 months ago
This is my second time writing this comment, so sorry for any notification spam. It seems related enough to this issue, but if this is the wrong place for it, let me know and I'll delete it.
I've run into another edge case with mx.compile and custom dataclasses, and I'm not sure why the behavior is occurring:
import mlx.core as mx
from collections import namedtuple
exampleClass = namedtuple('Example', ['x','y'])
example_tuple = exampleClass(x=0,y=1)
def foo(mytuple):
    return mytuple[0] + mytuple[1]
print(foo(mx.array([0,1]))) # works outputs array(1, dtype=int32)
print(foo(example_tuple)) # works outputs 1
compiled_foo = mx.compile(foo)
print(compiled_foo(mx.array([0,1]))) # outputs array(1, dtype=int32)
print(compiled_foo(example_tuple)) # outputs None (?huh?)
In the namedtuple example you are not doing any array operations, so compiling through that doesn't make sense (the 0 and 1 never get cast to mx.array). You can fix it by doing:
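import mlx.core as mx
from collections import namedtuple
exampleClass = namedtuple('Example', ['x','y'])
# Wrap the fields in mx.array so the compiled function sees array inputs
example_tuple = exampleClass(x=mx.array(0),y=mx.array(1))
def foo(mytuple):
    return mytuple[0] + mytuple[1]
compiled_foo = mx.compile(foo)
print(compiled_foo(example_tuple))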
We should have better error messaging (or find a way to support it).
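Until something like that exists, one possible stopgap is a user-side check that rejects non-array inputs before the compiled function is called. The sketch below is only an illustration: `find_non_array_leaves` and `checked` are hypothetical helpers, not part of MLX, and `FakeArray` is a stand-in so the snippet runs without MLX installed. In real use you would pass `mx.array` as `array_type`.

```python
from collections import namedtuple


def find_non_array_leaves(tree, array_type, path="arg"):
    """Return the paths of leaves in `tree` that are not `array_type` instances."""
    if isinstance(tree, array_type):
        return []
    if isinstance(tree, dict):
        items = [(f"{path}[{k!r}]", v) for k, v in tree.items()]
    elif isinstance(tree, (list, tuple)):  # namedtuples are tuples too
        items = [(f"{path}[{i}]", v) for i, v in enumerate(tree)]
    else:
        return [path]  # a plain Python leaf such as 0 or 1
    bad = []
    for child_path, child in items:
        bad.extend(find_non_array_leaves(child, array_type, child_path))
    return bad


def checked(fn, array_type):
    """Wrap `fn` so that non-array inputs raise a TypeError before `fn` runs."""
    def wrapper(*args):
        bad = find_non_array_leaves(args, array_type)
        if bad:
            raise TypeError(
                f"expected {array_type.__name__} leaves, "
                f"got non-array values at: {', '.join(bad)}"
            )
        return fn(*args)
    return wrapper


class FakeArray:
    """Stand-in for mx.array so this sketch runs without MLX (assumption)."""
    def __init__(self, value):
        self.value = value


Example = namedtuple("Example", ["x", "y"])
# Plain Python ints are reported as non-array leaves, with their positions
print(find_non_array_leaves((Example(x=0, y=1),), FakeArray))
```

With MLX available, wrapping the function as `checked(mx.compile(foo), mx.array)` would make the call with the namedtuple of plain ints raise a `TypeError` naming the offending fields, because the check runs before the compiled function is invoked.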
The fact that some of the following work but some don't seems inconsistent and unexpected. Filing this here mostly so I don't forget about it.