Open colehaus opened 8 months ago
>>> import jax
>>> jax.__version__
'0.4.21'
>>> from jax.dtypes import bfloat16
>>> f"{bfloat16(1)}"
'1'
>>> f"{bfloat16(1):.2f}"
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: Unknown format code 'f' for object of type 'str'
This is a little surprising to me and, in a larger context, the error message isn't very suggestive about what or where the error is.
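For anyone hitting the same thing, a possible workaround (a sketch, not a fix in JAX itself) is to convert the value to a built-in `float` before applying the format spec. The error text suggests the value is being turned into a string before the spec is applied, though I haven't confirmed that in the source. The helper name `format_float_like` below is hypothetical and only for illustration:

```python
def format_float_like(value, spec=".2f"):
    """Format any object that supports float() using a float format spec.

    Hypothetical helper: converting to a built-in float first sidesteps
    scalar types whose own formatting does not accept codes such as 'f'.
    """
    return format(float(value), spec)


# With JAX installed, the failing call from the report would become:
#   from jax.dtypes import bfloat16
#   format_float_like(bfloat16(1))
print(format_float_like(1))  # prints 1.00
```

Inline, the same idea is `f"{float(bfloat16(1)):.2f}"`. This assumes the scalar supports `float()`, and the formatted digits reflect the value after conversion to a Python float.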