Closed: justinchuby closed this issue 16 hours ago
From my experience implementing #166, I believe the present float8_base can be adapted without much difficulty to support FP4 (and FP6) dtypes. I would follow @cloudhan's suggestion of adding a type trait, sizeof_bits, specialized for the FP4/FP6 formats so that it reports the proper bit width of every dtype. Using sizeof_bits, TraitsBase can then be extended with proper bitmasks for the exponent and mantissa parts (https://github.com/jax-ml/ml_dtypes/blob/main/ml_dtypes/include/float8.h#L891). The rest should almost work out of the box :)
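To make the idea concrete, here is a minimal C++ sketch of what such a trait and the derived masks could look like. The type names and member names below (`float4_e2m1`, `TraitsSketch`, `kExponentMask`, etc.) are illustrative assumptions, not the actual ml_dtypes identifiers:

```cpp
#include <cstdint>

// Hypothetical tag types standing in for the real ml_dtypes classes.
struct float4_e2m1 {};
struct float6_e3m2 {};
struct float8_e4m3 {};

// A sizeof_bits trait in the spirit of the suggestion above: report the true
// bit width of each dtype, since sizeof() works in bytes and would round
// FP4/FP6 storage up to 8 bits.
template <typename T>
struct sizeof_bits {
  static constexpr int value = sizeof(T) * 8;  // byte-based default
};
template <> struct sizeof_bits<float4_e2m1> { static constexpr int value = 4; };
template <> struct sizeof_bits<float6_e3m2> { static constexpr int value = 6; };
template <> struct sizeof_bits<float8_e4m3> { static constexpr int value = 8; };

// Sketch of how TraitsBase-style masks could be derived for a signed
// E<e>M<m> format: 1 sign bit, kExpBits exponent bits, kMantBits mantissa
// bits, packed from the most significant end.
template <int kExpBits, int kMantBits>
struct TraitsSketch {
  static constexpr int kBits = 1 + kExpBits + kMantBits;
  static constexpr uint8_t kMantissaMask = (1u << kMantBits) - 1;
  static constexpr uint8_t kExponentMask = ((1u << kExpBits) - 1) << kMantBits;
  static constexpr uint8_t kSignMask = 1u << (kExpBits + kMantBits);
};

// FP4 E2M1: 1 sign bit, 2 exponent bits, 1 mantissa bit.
using FP4E2M1Traits = TraitsSketch<2, 1>;
static_assert(FP4E2M1Traits::kBits == 4, "");
static_assert(FP4E2M1Traits::kMantissaMask == 0b0001, "");
static_assert(FP4E2M1Traits::kExponentMask == 0b0110, "");
static_assert(FP4E2M1Traits::kSignMask == 0b1000, "");
```

The static_asserts check the E2M1 layout at compile time; the same template instantiated as `TraitsSketch<3, 2>` would cover an FP6 E3M2 layout.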
@hawkinsp thanks! Do you know when it will be released?
Today, hopefully.
That’s amazing, thank you!
Creating this thread from https://github.com/jax-ml/ml_dtypes/issues/116 for a focused proposal on supporting FP4E2M1. Thanks!