jax-ml / ml_dtypes

A stand-alone implementation of several NumPy dtype extensions used in machine learning.
Apache License 2.0

Support unpacked FP4E2M1 #164

Closed · justinchuby closed 16 hours ago

justinchuby commented 1 month ago

Creating this thread from https://github.com/jax-ml/ml_dtypes/issues/116 for a focused proposal on supporting FP4E2M1. Thanks!
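For context, FP4E2M1 packs 1 sign bit, 2 exponent bits, and 1 mantissa bit into 4 bits, with exponent bias 1 and no inf/NaN encodings (following the OCP MX spec). A small sketch decoding all 16 codes:

```python
# Decode a 4-bit FP4E2M1 code to a Python float. Bias 1, no inf/NaN,
# so the representable values are +/-{0, 0.5, 1, 1.5, 2, 3, 4, 6}.
def decode_fp4_e2m1(code: int) -> float:
    sign = -1.0 if code & 0b1000 else 1.0
    exponent = (code & 0b0110) >> 1
    mantissa = code & 0b0001
    if exponent == 0:  # subnormal: 0 or 0.5
        return sign * mantissa * 0.5
    return sign * (1 + mantissa / 2) * 2.0 ** (exponent - 1)

print([decode_fp4_e2m1(c) for c in range(16)])
```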

balancap commented 1 month ago

From my experience implementing #166, I believe the present float8_base class can be adapted without much difficulty to support the FP4 (and FP6) dtypes.

I would follow @cloudhan's suggestion: add a sizeof_bits type trait, specialized for the FP4/FP6 formats, that reports the proper bit width of every dtype. Using sizeof_bits, TraitsBase can then be extended to derive the proper bit masks for the exponent and mantissa parts (https://github.com/jax-ml/ml_dtypes/blob/main/ml_dtypes/include/float8.h#L891). The rest should work almost out of the box :) A sketch of the mask derivation is below.
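To make the layout concrete, here is a minimal Python illustration (not the library's C++ code) of the masks a sizeof_bits-driven TraitsBase would produce; the function name and parameters are hypothetical:

```python
# A sketch of deriving sign/exponent/mantissa bit masks from a per-format
# bit width, mirroring what a sizeof_bits-aware TraitsBase could compute.
def masks(total_bits: int, exponent_bits: int, mantissa_bits: int):
    assert total_bits == 1 + exponent_bits + mantissa_bits  # one sign bit
    sign_mask = 1 << (total_bits - 1)
    exponent_mask = ((1 << exponent_bits) - 1) << mantissa_bits
    mantissa_mask = (1 << mantissa_bits) - 1
    return sign_mask, exponent_mask, mantissa_mask

# FP4E2M1: 0b1000, 0b0110, 0b0001
print([bin(m) for m in masks(4, 2, 1)])
# FP8E4M3: the existing float8 layout falls out of the same rule.
print([bin(m) for m in masks(8, 4, 3)])
```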

hawkinsp commented 16 hours ago

https://github.com/jax-ml/ml_dtypes/pull/181 did this!
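For anyone wanting to try it, here is a minimal usage sketch; it assumes the dtype is exposed as ml_dtypes.float4_e2m1fn (matching the library's existing naming scheme) and stored unpacked, one byte per element:

```python
import numpy as np
import ml_dtypes

# Inputs are rounded to the nearest representable FP4E2M1 value.
x = np.array([0.5, 1.5, 2.0, 5.0], dtype=ml_dtypes.float4_e2m1fn)
print(x.astype(np.float32))
print(x.itemsize)  # unpacked: one byte per 4-bit element
```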

justinchuby commented 15 hours ago

@hawkinsp thanks! Do you know when it will be released?

hawkinsp commented 15 hours ago

Today, hopefully.

justinchuby commented 14 hours ago

That’s amazing, thank you!