jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

jax.debug.breakpoint gives UnexpectedTracerError when used with jax.lax.cond #23555

Open billmark opened 2 weeks ago

billmark commented 2 weeks ago

Description

import jax
import jax.numpy as jnp

def f(x, example):
  jax.lax.cond(example == 1, jax.debug.breakpoint, lambda *args: None)
  return x

f_vmap = jax.vmap(f, in_axes=(0, None), out_axes=0)

def g(x, example):
  return f_vmap(x, example)

x = jnp.arange(4)
example = jnp.array(0, dtype=jnp.int32)
g(x, example)
example = jnp.array(1, dtype=jnp.int32)
g(x, example)

Gives the error

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type int32[] wrapped in a BatchTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.Detail: Different traces at same level: Traced<ShapedArray(int32[])>with<BatchTrace(level=1/0)> with
  val = Array([0, 1, 2, 3], dtype=int32)
  batch_dim = 0, BatchTrace(level=1/0)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError

I believe this code should work. JAX expert Jake VanderPlas confirmed that I seem to have uncovered a bug and asked me to file this GitHub issue. Jake said that he had to set JAX_CHECK_TRACER_LEAKS=1 to observe the problem, but that was with a slightly different repro case. I did not have to do this (maybe it's already on in my environment?).

Additional context: The purpose of this code is to enter the debugger on a particular train step, so that I can examine variables inside of f() at that train step.

System info (python version, jaxlib version, accelerator, etc.)

Python: 3.11
JAX: Top of tree inside Google as of 3pm Pacific Time on Sept 10, 2024
Accelerator: TPU

mattjj commented 2 weeks ago

Thanks for reporting this, and the clear repro.

Maybe related: #16732

billmark commented 2 weeks ago

For others hitting this bug, I found a workaround in someone else's code: wrap the call to jax.debug.breakpoint in a lambda. That is, change

jax.lax.cond(example == 1, jax.debug.breakpoint, lambda *args: None)

to

jax.lax.cond(example == 1, lambda: jax.debug.breakpoint(), lambda *args: None)

As soon as you enter the debugger, use the "up" command to go one frame up in the stack (to escape the lambda), and then you can look at the variables you wanted to look at. (But JAX team, please do fix the bug, so this workaround isn't necessary!)
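For reference, here is the original repro with the lambda workaround applied, as a minimal sketch. The names `no_break`, `out_first` and `out_second` are made up for illustration. The predicate is left at 0 on both calls so the script runs without opening the debugger. I have only described the workaround for my own setup, so treat this as an assumption about other JAX versions rather than a guarantee.

```python
import jax
import jax.numpy as jnp


def f(x, example):
  # Workaround: pass a fresh lambda that calls jax.debug.breakpoint(),
  # instead of passing jax.debug.breakpoint itself as the branch function.
  jax.lax.cond(example == 1, lambda: jax.debug.breakpoint(), lambda *args: None)
  return x


f_vmap = jax.vmap(f, in_axes=(0, None), out_axes=0)


def g(x, example):
  return f_vmap(x, example)


x = jnp.arange(4)

# Predicate is 0, so the breakpoint branch is traced but not executed.
no_break = jnp.array(0, dtype=jnp.int32)
out_first = g(x, no_break)
out_second = g(x, no_break)

# To actually stop in the debugger, call this from an interactive session:
#   g(x, jnp.array(1, dtype=jnp.int32))
```

Once the debugger opens on the `example == 1` call, "up" moves out of the lambda frame and into f().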

mattjj commented 2 weeks ago

That repro hangs my machine :grin-sweat: