ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

[BUG] arithmetic operations with numpy arrays are not commutative #1066

Closed: lkarthee closed this 1 week ago

lkarthee commented 2 weeks ago

Describe the bug: When an mx.array is the first operand of an arithmetic op and the second operand is a NumPy array, the result is converted to complex numbers (complex64).

To Reproduce

import numpy as np
import mlx.core as mx

mx.array(1) + np.array(2)
x = mx.array(1)
y = np.array(2)
x + y 
# >>> array(3+0j, dtype=complex64)
x * y
# >>> array(2+0j, dtype=complex64)
y + x
# >>> 3
type(y + x)
# >>> <class 'numpy.int64'>
type(x + y)
# >>> <class 'mlx.core.array'>
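
One plausible explanation for the complex64 result, not confirmed anywhere in this thread: the binding may try Python int, then float, then complex when it converts a scalar operand. A 0-d NumPy array fails the first two checks but passes the complex conversion. The NumPy-only sketch below shows the two properties such a path would rely on. The conversion order inside MLX is an assumption.

```python
import numpy as np

y = np.array(2)  # 0-d integer array, as in the reproduction above

# A 0-d ndarray is not an instance of Python's int or float,
# so exact int/float checks would reject it...
print(isinstance(y, int), isinstance(y, float))  # False False

# ...but it implements __complex__, so a generic "convert to complex" path accepts it
print(complex(y))  # (2+0j)
```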

Expected behavior: Conversion to complex should not happen. Arithmetic ops should be commutative (ignoring whether the result is an np.ndarray or an mx.array).
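
The "ignoring type" caveat follows from how Python dispatches binary operators. For x + y, Python calls type(x).__add__ first and only falls back to type(y).__radd__ if that returns NotImplemented. The left operand therefore decides the result type, which is why x + y yields an mx.array while y + x yields a NumPy value. The sketch below uses two hypothetical classes, Left and Right, to show that the values can agree while the result types differ.

```python
class Left:
    """Hypothetical array type whose __add__ accepts any operand."""
    def __init__(self, v):
        self.v = v

    def __add__(self, other):
        # Left handles the op itself, so Python never consults the other operand
        return ("Left", self.v + getattr(other, "v", other))


class Right:
    """Hypothetical array type with its own __add__."""
    def __init__(self, v):
        self.v = v

    def __add__(self, other):
        return ("Right", self.v + getattr(other, "v", other))


a, b = Left(1), Right(2)
print(a + b)  # ('Left', 3): the left operand's __add__ wins
print(b + a)  # ('Right', 3): same value, different result type
```

Commutativity in the sense of the issue means the values (and dtypes) should match. The container type legitimately depends on operand order.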


awni commented 2 weeks ago

That looks like a bug 🤔

lkarthee commented 2 weeks ago

I am not sure if this is related, so I am adding it here. Let me know if it's not related and I will open another issue.

An np.ndarray of dtype bfloat16 (from ml_dtypes) is interpreted as complex64 by MLX.

>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> x = np.array(1., dtype=bfloat16)
>>> import mlx.core as mx
>>> mx.array(x)
array(1+0j, dtype=complex64)
>>> x = np.array(1)
>>> mx.array(x)
array(1, dtype=int64)
>>> x = np.array(1., dtype=bfloat16)
>>> x.dtype
dtype(bfloat16)
>>> type(x.dtype)
<class 'numpy.dtype[bfloat16]'>
>>>
awni commented 2 weeks ago

The bfloat16 thing is a different issue. I will send a fix for the add shortly. Could you put the bfloat16 problem in a separate issue? It might be harder to fix or require some changes in nanobind (still investigating).