ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

[BUG] `np.ndarray` of bfloat16 using ml_dtypes is being interpreted as complex64 #1075

Status: Open. lkarthee opened this issue 2 weeks ago.


Describe the bug

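A NumPy array (`np.ndarray`) whose dtype is `bfloat16` from ml_dtypes is interpreted by MLX as `complex64` when passed to `mx.array`. Standard NumPy dtypes such as int64 convert correctly.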

To Reproduce

>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> x = np.array(1., dtype=bfloat16)
>>> import mlx.core as mx
>>> mx.array(x)
array(1+0j, dtype=complex64)
>>> x = np.array(1)
>>> mx.array(x)
array(1, dtype=int64)
>>> x = np.array(1., dtype=bfloat16)
>>> x.dtype
dtype(bfloat16)
>>> type(x.dtype)
<class 'numpy.dtype[bfloat16]'>
>>>

Expected behavior

The array should not be converted to complex. `mx.array(x)` should return an array with dtype `bfloat16`.
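Until the conversion is fixed, one possible workaround is to upcast bfloat16 arrays to float32 in NumPy before calling `mx.array`, then cast back on the MLX side with `.astype(mx.bfloat16)`. The sketch below is not an MLX or ml_dtypes API: `prepare_for_mlx` is a hypothetical helper, and it assumes ml_dtypes reports the dtype name as `"bfloat16"`, which matches the `dtype(bfloat16)` repr in the session above.

```python
import numpy as np


def prepare_for_mlx(x: np.ndarray):
    """Hypothetical helper: upcast ml_dtypes bfloat16 arrays to float32.

    Returns (array, was_bfloat16). The flag tells the caller whether to
    cast back to bfloat16 after the array has been handed to MLX.
    """
    # Assumption: ml_dtypes registers its custom dtype under the name
    # "bfloat16". Checking the name avoids importing ml_dtypes here.
    if x.dtype.name == "bfloat16":
        # Every bfloat16 value is exactly representable in float32,
        # so this upcast loses no information.
        return x.astype(np.float32), True
    # Standard NumPy dtypes are passed through unchanged.
    return x, False


# Usage (requires mlx and ml_dtypes to be installed):
#   import mlx.core as mx
#   arr, was_bf16 = prepare_for_mlx(x)
#   y = mx.array(arr)
#   if was_bf16:
#       y = y.astype(mx.bfloat16)
```

Routing through float32 keeps the round trip exact, because converting float32 values that originated as bfloat16 back to bfloat16 reproduces the original values.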


Additional context

Originally posted by @lkarthee in https://github.com/ml-explore/mlx/issues/1066#issuecomment-2089573368