lkarthee opened 2 weeks ago
Describe the bug
A NumPy ndarray with the bfloat16 dtype from ml_dtypes is converted to complex64 when passed to mx.array.
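For context, bfloat16 is a 2-byte real floating-point format: 1 sign bit, 8 exponent bits and 7 mantissa bits, i.e. the upper half of an IEEE-754 float32. A bfloat16 value such as 1.0 is therefore a real number stored in two bytes, and its natural MLX counterpart is mx.bfloat16 (2 bytes) rather than complex64 (8 bytes). The stdlib-only sketch below illustrates that layout. The helper names are hypothetical, and it truncates for simplicity, whereas real libraries such as ml_dtypes typically round to nearest even.

```python
import struct


def float32_to_bfloat16_bits(value: float) -> int:
    """Hypothetical helper: return the 16 bfloat16 bits for a float32 value."""
    # Encode as a little-endian IEEE-754 float32 and read back its 32 raw bits.
    (bits,) = struct.unpack("<I", struct.pack("<f", value))
    # bfloat16 keeps the upper 16 bits: sign, 8-bit exponent, top 7 mantissa bits.
    # This sketch truncates; ml_dtypes and similar libraries round instead.
    return bits >> 16


def bfloat16_bits_to_float(bits: int) -> float:
    """Hypothetical helper: widen 16 bfloat16 bits back to a Python float."""
    # Widening is exact: put the bits in the upper half of a float32, zero the rest.
    (value,) = struct.unpack("<f", struct.pack("<I", (bits & 0xFFFF) << 16))
    return value


for v in (1.0, -2.5, 0.5):
    b = float32_to_bfloat16_bits(v)
    print(f"{v:>5} -> 0x{b:04X} -> {bfloat16_bits_to_float(b)}")
```

Values that fit in 7 mantissa bits, like those above, survive the round trip unchanged, so a correct conversion of the array in this report should yield a real value of 1.0.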
To Reproduce
```
>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> x = np.array(1., dtype=bfloat16)
>>> import mlx.core as mx
>>> mx.array(x)
array(1+0j, dtype=complex64)
>>> x = np.array(1)
>>> mx.array(x)
array(1, dtype=int64)
>>> x = np.array(1., dtype=bfloat16)
>>> x.dtype
dtype(bfloat16)
>>> type(x.dtype)
<class 'numpy.dtype[bfloat16]'>
```
Expected behavior
The array should not be converted to complex; it should keep the bfloat16 dtype (mx.bfloat16).
Additional context
Originally posted by @lkarthee in https://github.com/ml-explore/mlx/issues/1066#issuecomment-2089573368