jax-ml / ml_dtypes

A stand-alone implementation of several NumPy dtype extensions used in machine learning.
Apache License 2.0
191 stars, 25 forks

Simplify dtype registration logic #132

Closed. jakevdp closed this 8 months ago.

jakevdp commented 8 months ago

The checks for already-registered types were added in the precursor to ml_dtypes, at a time when both JAX and TensorFlow registered their own copies of bfloat16. Since that is no longer a concern, we can remove this logic.
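To make the removed logic concrete, here is a minimal Python sketch of the two registration styles. It assumes a plain dict stands in for NumPy's type registry. `DType`, `register_with_guard`, and `register_simple` are hypothetical names used only for illustration. The real ml_dtypes registration is C++ code against the NumPy C API and differs in detail.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass(frozen=True)
class DType:
    """Hypothetical stand-in for a registered extension dtype."""
    name: str
    owner: str


def register_with_guard(registry: Dict[str, DType], name: str, owner: str) -> DType:
    """Old style (sketch): reuse a type another library already registered."""
    existing = registry.get(name)
    if existing is not None:
        # Another library (e.g. JAX or TensorFlow) registered it first; reuse it.
        return existing
    dtype = DType(name, owner)
    registry[name] = dtype
    return dtype


def register_simple(registry: Dict[str, DType], name: str, owner: str) -> DType:
    """Simplified style (sketch): one owner registers each type, no lookup."""
    dtype = DType(name, owner)
    registry[name] = dtype
    return dtype


# With two libraries each shipping bfloat16, the guard made the second call a no-op.
legacy: Dict[str, DType] = {}
first = register_with_guard(legacy, "bfloat16", "jax")
second = register_with_guard(legacy, "bfloat16", "tensorflow")
assert first is second and second.owner == "jax"

# With a single provider, the lookup is unnecessary.
current: Dict[str, DType] = {}
bf16 = register_simple(current, "bfloat16", "ml_dtypes")
assert current["bfloat16"] is bf16
```

The guard only matters when more than one library may register the same type; once ml_dtypes is the sole provider, each type is registered exactly once and the lookup can go.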

jakevdp commented 8 months ago

There's still one more follow-up here: allocating the dtypes statically rather than on the heap. I plan to do that in a later PR.