jax-ml / ml_dtypes

A stand-alone implementation of several NumPy dtype extensions used in machine learning.
Apache License 2.0

bfloat16s (at least) throw an error when given a format specifier #134

Open colehaus opened 8 months ago

colehaus commented 8 months ago
>>> import jax
>>> jax.__version__
'0.4.21'
>>> from jax.dtypes import bfloat16
>>> f"{bfloat16(1)}"
'1'
>>> f"{bfloat16(1):.2f}"
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: Unknown format code 'f' for object of type 'str'

This is a little surprising to me. In a larger program, the error message also gives little hint of what went wrong or where: it complains about an object of type 'str', even though the value being formatted is a bfloat16.
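The traceback mentions `str` even though a bfloat16 was formatted. One plausible explanation, assumed here rather than confirmed from the ml_dtypes source, is that the scalar's `__format__` falls back to formatting `str(self)` because the type is not treated as a NumPy floating type, and `str` rejects numeric codes such as `f`. The sketch below reproduces that behaviour with a hypothetical `FakeBF16` class, so it runs without JAX or ml_dtypes. It also shows a simple workaround: convert to a Python `float` before applying the spec. `FakeBF16` and `format_float_like` are illustrative names and are not part of either library.

```python
class FakeBF16:
    """Hypothetical stand-in for a bfloat16 scalar (assumed behaviour, not ml_dtypes code)."""

    def __init__(self, value: float) -> None:
        self._value = float(value)

    def __float__(self) -> float:
        return self._value

    def __str__(self) -> str:
        # Print compactly, like the REPL output above ("1" rather than "1.0").
        return format(self._value, "g")

    def __format__(self, spec: str) -> str:
        # Assumed fallback: format the *string* form of the value, so numeric
        # presentation codes such as "f" are rejected by str.__format__.
        return format(str(self), spec)


def format_float_like(x, spec: str) -> str:
    """Workaround sketch: convert to a Python float, then apply the format spec."""
    return format(float(x), spec)


x = FakeBF16(1)
print(f"{x}")  # an empty spec is accepted by str: prints 1
try:
    print(f"{x:.2f}")
except ValueError as err:
    print(err)  # Unknown format code 'f' for object of type 'str'
print(format_float_like(x, ".2f"))  # 1.00
```

With a real bfloat16 scalar, the same workaround amounts to writing `f"{float(bfloat16(1)):.2f}"`. bfloat16 scalars support `float()`, so this should produce `'1.00'`, at the cost of an explicit conversion at every call site.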