jax-ml / ml_dtypes

A stand-alone implementation of several NumPy dtype extensions used in machine learning.
Apache License 2.0

Refactor FloatPyCast for improved performance using lookup table #151

Open vxst opened 3 months ago

vxst commented 3 months ago

Since a float8 value (whichever flavor) has only 256 possible bit patterns, we can use a technique similar to the S-box in AES: precompute the conversion mapping once and look results up when performing conversions.

This project targets Python, so float8.h (perhaps?) won't be used directly by users, and anyone who does use float8.h directly might want more control (e.g. for debugging). I therefore implemented the lookup table in the FloatPyCast function rather than in float8.h.
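
To illustrate the idea only (this is not the code in the PR, which changes the C++ FloatPyCast function), here is a minimal Python sketch of a table-driven float8 cast. It assumes the float8_e4m3fn layout (1 sign bit, 4 exponent bits, 3 mantissa bits, bias 7, no infinities, NaN encoded as S.1111.111). The names `decode_e4m3fn`, `TABLE`, and `cast_with_table` are hypothetical and exist only for this example.

```python
import math

def decode_e4m3fn(bits: int) -> float:
    """Decode one float8_e4m3fn bit pattern (0..255) arithmetically (slow path)."""
    sign = -1.0 if bits & 0x80 else 1.0
    exp = (bits >> 3) & 0xF   # 4 exponent bits
    man = bits & 0x7          # 3 mantissa bits
    if exp == 0xF and man == 0x7:
        return math.nan       # the only NaN encoding per sign; no infinities
    if exp == 0:
        # Subnormal: (man / 8) * 2**(1 - 7) == man * 2**-9
        return sign * man * 2.0 ** -9
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)

# Precompute all 256 results once, like an S-box.
TABLE = tuple(decode_e4m3fn(b) for b in range(256))

def cast_with_table(raw: bytes) -> list[float]:
    """Convert raw float8 bytes to Python floats with one table lookup each."""
    return [TABLE[b] for b in raw]
```

For example, `cast_with_table(bytes([0x38, 0xC0]))` yields `[1.0, -2.0]`. The per-element work becomes a single index operation, at the cost of one 256-entry table per source/target type pair.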

Ref. #150

google-cla[bot] commented 3 months ago

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

vxst commented 3 months ago

CLA added