Closed: mjbaldwin closed this issue 1 week ago.
This error appears to be caused by the following line inside debugging.debug_print:

formatter.format(fmt, *args, **kwargs)

which tries to perform the formatting at trace time, when the arguments are still tracers and don't yet have concrete values.
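To make this diagnosis concrete, here is a pure-Python sketch that does not use JAX at all. FakeTracer is a hypothetical stand-in for a tracer, assuming the tracer has no custom __format__: like any such object, it accepts the empty spec used by "{}" but raises TypeError for a spec such as ".2f". eager_check mirrors the quoted formatter.format call. deferred_check is one possible alternative (an illustration, not JAX's actual fix) that validates the replacement fields against placeholder objects and leaves the real formatting for later, when values are concrete.

```python
import string

formatter = string.Formatter()


class FakeTracer:
    """Hypothetical stand-in for a JAX tracer: no concrete value and no custom
    __format__, so object.__format__ accepts only an empty format spec."""

    def __repr__(self):
        return "FakeTracer<float32[]>"


def eager_check(fmt, *args, **kwargs):
    # Mirrors the quoted line: __format__ is called on every argument
    # immediately, i.e. at "trace time", while the arguments are still tracers.
    return formatter.format(fmt, *args, **kwargs)


class _AcceptAnySpec:
    # Placeholder whose __format__ accepts any spec and produces no text.
    def __format__(self, spec):
        return ""


def deferred_check(fmt, *args, **kwargs):
    # Hypothetical alternative: check that every replacement field in fmt can be
    # matched to an argument by formatting placeholders instead of the real
    # arguments. Fields with attribute/index access (e.g. "{x.shape}") are not
    # supported by this simplified sketch.
    placeholders = [_AcceptAnySpec() for _ in args]
    named = {key: _AcceptAnySpec() for key in kwargs}
    formatter.format(fmt, *placeholders, **named)
    return fmt  # the real formatting would happen later, on concrete values


if __name__ == "__main__":
    t = FakeTracer()
    print(eager_check("{}", t))  # empty spec: falls back to str(t)
    print(deferred_check("{:.2f}", t))  # validates without formatting t
    try:
        eager_check("{:.2f}", t)
    except TypeError as err:
        print("TypeError:", err)
```

The eager version fails for exactly the case reported here ("{:.2f}") while "{}" keeps working, which matches the behavior described in the issue.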
Description
jax.debug.print() supports printing variable values, e.g.

a = 5; jax.debug.print("{}", a)

outputs 5. This also works inside a function that has been JAX-compiled. And when called without any compilation, it supports formatting values in the standard way, e.g.

a = 5; jax.debug.print("{:.2f}", a)

outputs 5.00. But if you include any value formatting in a compiled function, running the function produces an error:
Yet the documentation for jax.debug.print clearly states:
I encountered this bug because I'm using jax.debug.print() to monitor several values at a time during execution while debugging, and I'd like each row of values to have a similar length so the columns are easier to read. Floats that might print as 10.119900703430176 or might print as 2 make the debugging output harder to read.

Here is a Colab notebook with example code, where you can see the results and run it yourself.
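For fixed-width columns, one possible workaround is to move the formatting into a host callback, where the values are concrete, so the format spec is never applied to a tracer. This is a sketch assuming jax.debug.callback, which passes concrete array values to an ordinary Python function; _print_row, step, WIDTH and PRECISION are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp

WIDTH, PRECISION = 10, 4  # hypothetical column layout; adjust to taste


def _print_row(*values):
    # Runs on the host with concrete values, so ordinary format specs are safe.
    print(" ".join(f"{float(v):{WIDTH}.{PRECISION}f}" for v in values))


@jax.jit
def step(x, y):
    # Only the traced values are passed through; formatting happens in _print_row.
    jax.debug.callback(_print_row, x, y)
    return x + y


if __name__ == "__main__":
    step(jnp.float32(10.1199), jnp.float32(2.0)).block_until_ready()
    jax.effects_barrier()  # wait until the host callback has finished printing
```

With these settings both 10.1199 and 2.0 are printed as 10-character columns, so rows of several values line up.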
System info (python version, jaxlib version, accelerator, etc.)
Running on current Google Colab.