stanford-crfm / haliax

Named Tensors for Legible Deep Learning in JAX
Apache License 2.0

better stack trace for concretization inside scan #57

Open dlwh opened 8 months ago

dlwh commented 8 months ago

JAX exceptions raised from inside scan give no indication of where in the user's code the error actually occurred. It would be better if we could catch these errors somehow and produce a more informative stack trace.
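One possible shape for the "catch this somehow" idea, as a minimal sketch in plain JAX. It assumes the error surfaces while the scan body is being traced, which is the case for concretization errors. `ScanBodyError` and `annotate_scan_body` are hypothetical names, not part of haliax or JAX. The wrapper catches `jax.errors.ConcretizationTypeError` inside the body and re-raises it, chained, as a new exception that names the function and where it is defined.

```python
import functools
import inspect

import jax
import jax.numpy as jnp


class ScanBodyError(RuntimeError):
    """Hypothetical error type that records which scanned function failed."""


def annotate_scan_body(fn):
    """Hypothetical helper: wrap a scan body so tracing errors name it.

    scan traces `fn` when it is called, so a ConcretizationTypeError
    (including its TracerBoolConversionError subclass) is raised
    synchronously inside `fn` and can be caught here.
    """
    try:
        _, first_line = inspect.getsourcelines(fn)
        location = f"{inspect.getsourcefile(fn)}:{first_line}"
    except (OSError, TypeError):
        location = "<source unavailable>"
    name = getattr(fn, "__qualname__", repr(fn))

    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except jax.errors.ConcretizationTypeError as e:
            # Chain the original error (raise ... from e) so it stays in the traceback.
            raise ScanBodyError(
                f"Error while tracing scan body {name} (defined at {location}): {e}"
            ) from e

    return wrapped


def body(carry, x):
    if x > 0:  # Python branch on a traced value: fails under scan
        carry = carry + x
    return carry, None


try:
    jax.lax.scan(annotate_scan_body(body), jnp.float32(0.0), jnp.arange(3.0))
except ScanBodyError as err:
    print(err)  # message now names `body` and its source location
```

A library-level version would presumably apply a wrapper like this inside its own scan implementation, so users would not have to wrap their layer functions by hand.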

For example, scanning over a layer method like this:

scan(self.foo, ...)

def foo(self, x: NamedArray, mask: Optional[AttentionMask | NamedArray], layer_idx, *, key):
    k1, k2, k3, k4 = haliax.jax_utils.maybe_rng_split(key, 4)

    attn_output = self.attn(self.ln_1(x), mask=mask, layer_idx=layer_idx, key=k1)
    attn_output = self.resid_dropout(attn_output, key=k2)
    x = x + attn_output

    ff_output = self.mlp(self.ln_2(x), key=k3)
    ff_output = self.resid_dropout(ff_output, key=k4)
    x = x + ff_output

    # import ipdb; ipdb.set_trace()
    # layer_idx is a tracer inside scan, so this Python `if` attempts boolean
    # conversion of a traced array and raises the error shown below.
    if jnp.equal(layer_idx.array, 4):
        # x = x + 0.01 * jnp.sin(x * 1e2)
        x = x + 0.01 * hax.sin(x * 1e2)

    return x

This produced:


    carry, ys = lax.scan(wrapped_fn, init, leaves, reverse=reverse, unroll=unroll)
  File "/nlp/scr/ahmedah/miniconda3/envs/locked/lib/python3.10/site-packages/haliax/hof.py", line 83, in wrapped_fn
    carry, y = f(carry, *args, **kwargs)
  File "/nlp/scr/ahmedah/miniconda3/envs/locked/lib/python3.10/site-packages/haliax/hof.py", line 124, in scan_compatible_fn
    return fn(carry, *args, **kwargs), None
  File "/nlp/scr/ahmedah/miniconda3/envs/locked/lib/python3.10/site-packages/haliax/jax_utils.py", line 69, in wrapper
    dynamic_out, static_out = checkpointed_fun(static, dynamic)
jax.errors.ConcretizationTypeError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function new_fun at /nlp/scr/ahmedah/miniconda3/envs/locked/lib/python3.10/site-packages/jax/_src/ad_checkpoint.py:357 for checkpoint. This concrete value was not available in Python because it depends on the value of the argument dyn_args[0][0][3][<flat index 0>].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

Consider using the `static_argnums` parameter for `jax.remat` or `jax.checkpoint`. See the `jax.checkpoint` docstring and its example involving `static_argnums`:
https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:
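Separately from the stack-trace request, the error in this example comes from the Python `if` on `layer_idx`, which is a tracer inside scan. Below is a minimal sketch of a trace-safe alternative in plain JAX, assuming it is acceptable to compute the perturbed value at every layer and select it with `jnp.where`. `layer_step`, `scan_body`, and the `x + 1.0` update are hypothetical stand-ins for the attention and MLP blocks, not haliax code.

```python
import jax
import jax.numpy as jnp


def layer_step(x, layer_idx):
    # Hypothetical stand-in for the attention and MLP residual updates.
    x = x + 1.0
    # Trace-safe version of `if jnp.equal(layer_idx, 4): x = x + ...`:
    # compute the perturbed value unconditionally, then keep it only
    # where layer_idx == 4. No Python bool() is called on a tracer.
    perturbed = x + 0.01 * jnp.sin(x * 1e2)
    return jnp.where(jnp.equal(layer_idx, 4), perturbed, x)


def scan_body(x, layer_idx):
    return layer_step(x, layer_idx), None


x0 = jnp.zeros((3,), dtype=jnp.float32)
final, _ = jax.lax.scan(scan_body, x0, jnp.arange(6))
```

`jax.lax.cond` is the other standard option inside scan when the branches are expensive; `jnp.where` is simpler but always evaluates both sides.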