jax-ml / ml_dtypes

A stand-alone implementation of several NumPy dtype extensions used in machine learning.
Apache License 2.0

numpy casting error: Cannot cast array data from dtype(bfloat16) to dtype(bfloat16) according to the rule 'unsafe' #145

Closed idanori closed 5 months ago

idanori commented 5 months ago

Environment: Python 3.10.6, tensorflow==2.12, jax==0.4.25, ml-dtypes>=0.3.2

Each of the following casts fails with TypeError: Cannot cast array data from dtype(bfloat16) to dtype(bfloat16) according to the rule 'unsafe'

import numpy as np
import tensorflow as tf

const = tf.constant([1,2], dtype=tf.bfloat16)
const_numpy = const.numpy()
try:
    const_numpy.astype('bfloat16')
except TypeError as e:
    print(f"astype cast {const_numpy.dtype} to 'bfloat16' error: {e}")

try:
    const_numpy.astype(np.dtype('bfloat16'))
except TypeError as e:
    print(f"astype cast {const_numpy.dtype} to np.dtype('bfloat16') error: {e}")

try:
    np.asarray(const_numpy, 'bfloat16')
except TypeError as e:
    print(f"asarray cast {const_numpy.dtype} to 'bfloat16' error: {e}")

try:
    np.asarray(const_numpy, np.dtype('bfloat16'))
except TypeError as e:
    print(f"asarray cast {const_numpy.dtype} to np.dtype('bfloat16') error: {e}")

jakevdp commented 5 months ago

Hi - this looks like it's unrelated to ml_dtypes. Tensorflow version 2.12 bundled its own definitions of bfloat16 and other custom types; it wasn't until version 2.14 that tensorflow began depending on and using ml_dtypes. I'd suggest updating to a more recent tensorflow release.

idanori commented 5 months ago

tensorflow 2.12 requires jax>=0.3.15, which in turn requires ml-dtypes>=0.2.0.

From the output of pip install tensorflow==2.12:

Collecting jax>=0.3.15 (from tensorflow==2.12)
Collecting ml-dtypes>=0.2.0 (from jax>=0.3.15->tensorflow==2.12)

jakevdp commented 5 months ago

I'm not sure why tensorflow 2.12 lists jax as a dependency, but regardless it's definitely the case that tensorflow 2.12 defines and registers its own bfloat16 type. I would raise an issue in the tensorflow repository (I suspect they will suggest updating to a more recent version).
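
One way to see which package actually defines the bfloat16 on an array is to look at the module of its scalar type. The helper below is only a sketch: dtype_origin is a hypothetical name, it is demonstrated on a built-in NumPy dtype, and what it reports for an array coming from tensorflow 2.12 has not been checked here.

```python
import numpy as np


def dtype_origin(arr):
    """Return 'module.name' for the scalar type behind arr's dtype.

    Hypothetical helper for illustration: two arrays can both print their
    dtype as 'bfloat16' while their scalar types live in different modules.
    """
    scalar_type = arr.dtype.type
    return f"{scalar_type.__module__}.{scalar_type.__qualname__}"


if __name__ == "__main__":
    # Demonstrated on a built-in dtype; pass const_numpy from the report
    # above to see where its bfloat16 scalar type is defined.
    print(dtype_origin(np.zeros(2, dtype=np.float32)))
```

If the value reported for const_numpy differs from the one reported for an array created with np.dtype('bfloat16'), the two dtypes share a name but come from separate registrations.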

idanori commented 5 months ago

If that is the case, how can it be explained that pip install ml-dtypes==0.3.1 solves the issue?

jakevdp commented 5 months ago

If you're importing both a newer ml_dtypes and an older tensorflow in the same environment, then having two different bfloat16 declarations can cause issues. ml_dtypes 0.3.1 was released in an era when tensorflow still defined its own bfloat16 type, so it has code to check whether another package has registered bfloat16. We removed that in later versions of ml_dtypes once jax and tensorflow removed their own bfloat16 registrations.

The net result is that if you use an old tensorflow release in the same environment as a new ml_dtypes release, you'll run into compatibility issues like the one you're seeing. The solution is to not import an old tensorflow and a new ml_dtypes in the same environment.

My suggestion would be to update tensorflow, as that's the easiest fix here. Alternatively, you can install the versions of the tensorflow dependencies that were current when v2.12 was released, and you should not see these kinds of compatibility issues.
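
A rough way to catch this combination before it bites is to compare the installed versions at startup. The sketch below is an assumption-laden illustration, not part of either library: the helper names are hypothetical, and the thresholds (tensorflow older than 2.14, ml-dtypes 0.3.2 or newer) are taken only from what this thread reports.

```python
from importlib import metadata


def parse_version(text):
    """Turn '2.12.0' or '0.3.2rc1' into a tuple of its leading integer parts."""
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def is_risky_combination(tf_version, ml_dtypes_version):
    """True for the pairing described in this thread (thresholds are assumptions).

    tensorflow < 2.14 still registers its own bfloat16, and the reporter saw
    the error with ml-dtypes >= 0.3.2 but not with 0.3.1.
    """
    old_tf = parse_version(tf_version) < (2, 14)
    new_ml_dtypes = parse_version(ml_dtypes_version) >= (0, 3, 2)
    return old_tf and new_ml_dtypes


def installed_version(dist_name):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    tf_version = installed_version("tensorflow")
    ml_version = installed_version("ml-dtypes")
    if tf_version and ml_version and is_risky_combination(tf_version, ml_version):
        print(f"tensorflow {tf_version} with ml-dtypes {ml_version}: "
              "expect bfloat16 casting errors; update tensorflow or pin ml-dtypes.")
```

The check looks only at the distribution names tensorflow and ml-dtypes, so variants installed under another name (for example tensorflow-cpu) would need to be added to the lookup.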