jax-ml / ml_dtypes

A stand-alone implementation of several NumPy dtype extensions used in machine learning.
Apache License 2.0

Fix casts from f8e5m2fnuz to half. #146

Closed copybara-service[bot] closed 6 months ago

copybara-service[bot] commented 6 months ago

Fix casts from f8e5m2fnuz to half.

The current implementation doesn't handle this case correctly, because exponent_shift is negative here.

The existing test does not catch this, because it only compares the cast via float against itself.
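
A test that does not go through the cast under test needs an independent reference. The following is a minimal sketch of one in pure Python, not the ml_dtypes implementation: `decode_f8e5m2fnuz` and `roundtrip_half` are hypothetical helper names, and the decoder assumes the usual f8e5m2fnuz layout (1 sign bit, 5 exponent bits, 2 mantissa bits, exponent bias 16, no infinities, 0x80 as the only NaN). Half uses exponent bias 15, so if exponent_shift denotes the bias difference between target and source (an assumption, the description does not define it), it would be 15 - 16 = -1 for this pair.

```python
import math
import struct

F8_BIAS = 16    # exponent bias of f8e5m2fnuz (assumed layout, see lead-in)
HALF_BIAS = 15  # exponent bias of IEEE binary16 (half)
F8_MANTISSA_BITS = 2


def decode_f8e5m2fnuz(bits: int) -> float:
    """Decode one f8e5m2fnuz bit pattern (0..255) to a Python float.

    Hypothetical reference helper, independent of any cast implementation.
    """
    if bits == 0x80:  # the single NaN encoding; there is no negative zero
        return math.nan
    sign = -1.0 if bits & 0x80 else 1.0
    exponent = (bits >> F8_MANTISSA_BITS) & 0x1F
    mantissa = bits & 0x3
    if exponent == 0:
        # Subnormal: no implicit leading one, exponent fixed at 1 - bias.
        return sign * math.ldexp(mantissa, 1 - F8_BIAS - F8_MANTISSA_BITS)
    # Normal: implicit leading one (the value 4 in units of 1/4).
    return sign * math.ldexp(4 + mantissa, exponent - F8_BIAS - F8_MANTISSA_BITS)


def roundtrip_half(x: float) -> float:
    """Round x to half precision using the struct module's 'e' format."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


# Every finite f8e5m2fnuz value should be exactly representable in half,
# so a correct cast to half must reproduce the decoded value bit for bit.
for bits in range(256):
    if bits == 0x80:
        continue
    value = decode_f8e5m2fnuz(bits)
    assert roundtrip_half(value) == value, hex(bits)
```

In an actual test, the decoded values would be compared against the result of the direct f8e5m2fnuz-to-half cast for all 256 bit patterns, rather than against a second cast that shares the same code path. Note that inputs with exponent field 1 land in half's subnormal range, which is the kind of case a negative shift has to handle.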