Open billmark opened 2 weeks ago
Thanks for reporting this, and the clear repro.
Maybe related: #16732
For others hitting this bug: I found a workaround in someone else's code.
Wrap the call to jax.debug.breakpoint in a lambda, i.e. change:
jax.lax.cond(example == 1, jax.debug.breakpoint, lambda *args: None)
to
jax.lax.cond(example == 1, lambda: jax.debug.breakpoint(), lambda *args: None)
As soon as you enter the debugger, use the "up" command to move one frame up the stack (out of the lambda); you can then inspect the variables you wanted to look at. (But JAX team, please do fix the bug so this workaround isn't necessary!)
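For anyone who wants to try the workaround in context, here is a minimal sketch. The function `f`, the constant `DEBUG_STEP`, and the arithmetic are hypothetical stand-ins (the original repro is not shown in this thread); only the `jax.lax.cond` call follows the workaround described above.

```python
import jax
import jax.numpy as jnp

DEBUG_STEP = 3  # hypothetical: the train step at which to enter the debugger


@jax.jit
def f(x, step):
    y = x * 2.0  # a local variable we might want to inspect
    # Workaround: wrap jax.debug.breakpoint in a lambda instead of
    # passing it to cond directly. Both branches return None.
    jax.lax.cond(step == DEBUG_STEP, lambda: jax.debug.breakpoint(), lambda: None)
    return y + 1.0


# With a step that does not match, the breakpoint branch is traced but
# never executed, so this call returns normally.
result = f(jnp.arange(3.0), 0)
```

Calling `f(jnp.arange(3.0), DEBUG_STEP)` instead should drop into the interactive debugger; from there, "up" should move out of the lambda and into `f`, where `x` and `y` are visible.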
That repro hangs my machine :grin-sweat:
Description
Gives the error
But I believe this code should work. JAX expert Jake VanderPlas has confirmed that I seem to have uncovered a bug and asked me to file this GitHub issue. Jake said he had to set JAX_CHECK_TRACER_LEAKS=1 to observe the problem, but that was with a slightly different repro case. I did not have to do this (maybe it's already on in my environment?).
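If you want to turn the tracer leak checker on explicitly rather than rely on the environment variable, one option is the `jax.checking_leaks()` context manager. This is only a sketch: the function `g` is a hypothetical placeholder, not the repro from this issue, and it just shows the checker enabled around a jitted call.

```python
import jax
import jax.numpy as jnp


@jax.jit
def g(x):
    # Hypothetical placeholder; substitute the function under investigation.
    return x + 1.0


# Enable tracer leak checking only within this block, which is intended to
# have the same effect as setting JAX_CHECK_TRACER_LEAKS=1 for that scope.
with jax.checking_leaks():
    checked = g(jnp.float32(1.0))
```

Scoping the check this way keeps it off for the rest of the program, which may help when comparing behavior with and without the checker.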
Additional context: The purpose of this code is to enter the debugger on a particular train step, so that I can examine variables inside of f() at that train step.
System info (python version, jaxlib version, accelerator, etc.)
Python: 3.11
JAX: Top of tree inside Google as of 3pm Pacific Time on Sept 10, 2024
Accelerator: TPU