jax-ml / ml_dtypes

A stand-alone implementation of several NumPy dtype extensions used in machine learning.
Apache License 2.0

TBox #150

Open · vxst opened this issue 3 months ago

vxst commented 3 months ago

Since a float8 value (whichever flavor) has only 256 possible bit patterns, we can use a technique similar to the S-box in AES: precompute the conversion mapping once and simply look it up when performing conversions.

Since conversion is one of the most frequently used operations on float8 values, this method will greatly improve performance for this library. It can be used for conversions between float8 formats and for conversions from float8 to float16/float32/float64.
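For illustration, here is a minimal sketch of the S-box-style approach, written against a generic element-wise conversion callable rather than this library's actual internals (all names below are hypothetical):

```cpp
#include <array>
#include <cstdint>

// Tabulate any element-wise conversion: precompute all 256 possible
// results once, so each later conversion is a single indexed load.
// (Illustrative names, not this library's API.)
template <typename To, typename Fn>
std::array<To, 256> BuildTable(Fn convert) {
  std::array<To, 256> table{};
  for (int i = 0; i < 256; ++i) {
    table[i] = convert(static_cast<std::uint8_t>(i));
  }
  return table;
}

// Usage, given some hypothetical slow path `SlowFloat8ToFloat`:
//   auto table = BuildTable<float>(SlowFloat8ToFloat);
//   float y = table[bits];  // the whole conversion is one load
```

Each conversion then costs one indexed load instead of the exponent/mantissa bit manipulation on the generic path.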

I plan to implement this on top of the current ConvertImpl, with a new struct ConvertTable. The table is populated using ConvertImpl, so the behavior will be exactly the same, just much faster. I plan to build the table at the RegisterTwoWayCustomCast stage and use it whenever the source type of the two-way cast is 8 bits or narrower.
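A rough sketch of what that could look like (the ConvertImpl declaration below is a stand-in; the real one may carry additional template parameters, e.g. saturation/truncation flags, which are elided here):

```cpp
#include <array>
#include <cstdint>
#include <cstring>

// Stand-in declaration for this library's existing element-wise
// conversion; elided details are assumptions, not the actual signature.
template <typename From, typename To>
struct ConvertImpl {
  static To run(From from);
};

// Sketch of the proposed ConvertTable: precomputes all 256 results via
// ConvertImpl at construction time, so lookups stay bit-exact with the
// existing path.
template <typename From, typename To>
struct ConvertTable {
  static_assert(sizeof(From) == 1, "lookup table is for 8-bit sources");

  ConvertTable() {
    for (int i = 0; i < 256; ++i) {
      const std::uint8_t bits = static_cast<std::uint8_t>(i);
      From from;
      std::memcpy(&from, &bits, 1);  // reinterpret the bit pattern
      table_[i] = ConvertImpl<From, To>::run(from);
    }
  }

  To operator()(From from) const {
    std::uint8_t bits;
    std::memcpy(&bits, &from, 1);
    return table_[bits];  // single indexed load per element
  }

 private:
  std::array<To, 256> table_;
};
```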

Is there anything I should pay attention to, or do you have any advice (e.g., on naming)? I'm starting on the implementation now and will open a PR when it's finished.

vxst commented 3 months ago

I need your advice on whether to also use ConvertTable for bfloat16 to float8. A 64 KB lookup table (2^16 source bit patterns × 1 byte per float8 result) is still relatively small, and building it is negligible next to the cost of loading the library in Python. I'll implement ConvertTable so that it can be easily adapted to a 16-bit source format. Perhaps we can discuss this in the PR, once I can present concrete data such as measured speedups, so we can make a better-informed decision.
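For scale, a hypothetical sketch of that generalization (same illustrative convert-callable assumption as above), deriving the entry count from the source width so one helper covers both the 8-bit and 16-bit cases:

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <type_traits>

// Hypothetical generalized table builder (illustrative, not this
// library's API): the entry count follows from sizeof(From).
template <typename From, typename To, typename Fn>
auto BuildTable(Fn convert) {
  static_assert(sizeof(From) <= 2, "lookup only pays off for narrow sources");
  constexpr std::size_t kEntries = std::size_t{1} << (8 * sizeof(From));
  using Bits =
      std::conditional_t<sizeof(From) == 1, std::uint8_t, std::uint16_t>;
  std::array<To, kEntries> table{};
  for (std::size_t i = 0; i < kEntries; ++i) {
    const Bits bits = static_cast<Bits>(i);
    From from;
    std::memcpy(&from, &bits, sizeof(From));  // reinterpret the bit pattern
    table[i] = convert(from);
  }
  return table;  // real code would likely place a 64 KB table statically
}
```

For a 2-byte source such as bfloat16 and a 1-byte float8 result, this comes out to 65,536 entries × 1 byte = 64 KB, matching the estimate above.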