jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Tips for debugging NaNs in gradient? #475

Closed martiningram closed 11 months ago

martiningram commented 5 years ago

Hi there,

I am running an optimisation using gradients from Jax, and everything goes well for a number of steps until the gradients returned are all nan. I am having a bit of a hard time tracking down where the problem is; the forward calculations all seem to be fine.

Is there some way I can work out which operation is causing the nans from grad? This would be really useful.

Thanks!

mattjj commented 5 years ago

Great question. I don't think we have a solution in place for this right now, but I think we can make one.

There are at least two things to solve here:

  1. set up the equivalent of np.seterr(invalid="raise")
  2. catch nans on the backward pass, and associate them helpfully with user code
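
For reference, the NumPy behaviour named in item 1 can be sketched as follows. This is a minimal illustration in plain NumPy, not JAX; it uses np.errstate, the context-manager form of np.seterr, so the setting does not leak into the rest of the program.

```python
import numpy as np

# With invalid="raise", NumPy turns invalid floating-point operations
# (the ones that would otherwise return nan) into FloatingPointError.
with np.errstate(invalid="raise"):
    try:
        np.log(np.array([-1.0]))  # log of a negative number is invalid
        raised = False
    except FloatingPointError as err:
        raised = True
        print("caught:", err)
```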

martiningram commented 5 years ago

Thanks Matt, great that you think it's worthwhile to enhance this!

Although it's cool you have a clear roadmap, I am actually really blocked by this at the moment and was wondering if there are any things I could do in the meantime? I'd be happy to dig into some of the backend if required. I've already changed to float64 which has helped but not resolved things.

mattjj commented 5 years ago

Thanks for letting us know. Any chance you can share a small repro? I just want to make sure we provide the right pointers or tools.

mattjj commented 5 years ago

Blocking a user is the worst feeling! That's a magic word to get us to help you out ASAP :)

I added some basic nan debugging machinery in #482. As with other config options there are a few ways to turn it on:

  1. you can set the JAX_DEBUG_NANS environment variable to something truthy,
  2. you can add from jax.config import config and config.update("jax_debug_nans", True) near the top of your main file,
  3. you can add from jax.config import config and config.parse_flags_with_absl() to your main file, then set the option using a command-line flag like --jax_debug_nans=True.

Switching that option on adds a nan check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an @jit. For code under an @jit, the output of every @jit function is checked and if a nan is present it will re-run the function in de-optimized op-by-op mode (effectively removing one level of @jit at a time).
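
A minimal sketch of option 2 in action. It assumes a recent JAX release, where the flag can be set through jax.config.update directly (the from jax.config import config form above is how it was spelled when this comment was written). The flag is switched off again at the end so the check does not slow down later code.

```python
import jax
import jax.numpy as jnp

# Turn on the nan check for every value produced by a primitive.
jax.config.update("jax_debug_nans", True)

try:
    jnp.log(jnp.array(-1.0))  # produces nan, so the check should raise
    caught = False
except FloatingPointError as err:
    caught = True
    print("caught:", err)
finally:
    # Restore the default; the check adds overhead to every operation.
    jax.config.update("jax_debug_nans", False)
```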

There could be tricky situations that arise, like nans that only occur under a @jit but don't get produced in de-optimized mode. In that case you'll see a warning message print out but your code will continue to execute, so we can dig in deeper.

If the nans are being produced in the backward pass of a gradient evaluation, then when the exception is raised you'll find the backward_pass function several frames up in the stack trace. It is essentially a simple jaxpr interpreter that walks the sequence of primitive operations in reverse. If the culprit isn't immediately obvious, you can find the primitive that's producing the nan by running something like p eqn.primitive in an interactive debugger in that stack frame.

How does that sound? This is a good opportunity to add exactly the tooling we want; JAX is tiny and easy to instrument, so there's no reason not to get this right.

martiningram commented 5 years ago

Oh wow, thanks so much! Can't wait to try this out. So sorry, I didn't mean to make you feel bad! It's very likely this is some stupid mistake on my end, but I am very grateful that you added tooling so quickly!

mattjj commented 5 years ago

Hah don't worry, I was joking about feeling bad. But it does light a fire under us whenever someone is blocked :)

Check out this comment for a toy example of how to use this. It's going to be a bit hairier to debug nans in the backward pass, but hopefully not too bad.

When you get back to it, let us know how it goes, and any additional issues you run into. This kind of feedback is incredibly helpful, and it's going to pay off a lot in the future when it helps us build an awesome debugging experience.

martiningram commented 5 years ago

Hi Matt,

Thanks for this! Indeed it now crashes rather than returning nan, which is great. From the stack trace below, it looks like the mul operation raises the issue:

/Users/ingramm/Projects/software/jax/jax/lib/xla_bridge.py:128: UserWarning: No GPU found, falling back to CPU.
  warnings.warn('No GPU found, falling back to CPU.')
/Users/ingramm/Projects/software/jax/jax/numpy/linalg.py:51: UserWarning: numpy.linalg support is experimental and may cause silent failures or wrong outputs
  warnings.warn(_EXPERIMENTAL_WARNING)
Log posterior is Traced<ConcreteArray(-7522.515563681398)>with<JVPTrace(level=1/0)>.
Log determinant is Traced<ConcreteArray(247.07115427818476)>with<JVPTrace(level=1/0)>
Traceback (most recent call last):
  File "nan_gradient.py", line 179, in <module>
    data['l'], data['b'], data['n_c']))
  File "/Users/ingramm/Projects/software/jax/jax/api.py", line 206, in grad_f
    ans, g = value_and_grad_f(*args, **kwargs)
  File "/Users/ingramm/Projects/software/jax/jax/api.py", line 243, in value_and_grad_f
    g = vjp_py(onp.ones((), onp.result_type(ans)))
  File "/Users/ingramm/Projects/software/jax/jax/api_util.py", line 56, in apply_jaxtree_fun
    ans = fun(*args)
  File "/Users/ingramm/Projects/software/jax/jax/api.py", line 570, in out_vjp_packed
    return out_vjp(cotangent_in)
  File "/Users/ingramm/Projects/software/jax/jax/interpreters/ad.py", line 81, in vjp_
    _, arg_cts = backward_pass(jaxpr, consts, (), dummy_args, dummy_primal_and_ct)
  File "/Users/ingramm/Projects/software/jax/jax/interpreters/ad.py", line 139, in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(ct_in, *invals, **eqn.params)
  File "/Users/ingramm/Projects/software/jax/jax/interpreters/ad.py", line 338, in bilinear_transpose
    out = zero if cotangent is zero else lhs_rule(cotangent, y, **kwargs)
  File "/Users/ingramm/Projects/software/jax/jax/lax.py", line 242, in mul
    return mul_p.bind(x, y)
  File "/Users/ingramm/Projects/software/jax/jax/core.py", line 75, in bind
    return self.impl(*args, **kwargs)
  File "/Users/ingramm/Projects/software/jax/jax/interpreters/xla.py", line 54, in apply_primitive
    return compiled_fun(*args)
  File "/Users/ingramm/Projects/software/jax/jax/interpreters/xla.py", line 85, in execute_compiled_primitive
    return result_handler(compiled.Execute(input_bufs, not core.skip_checks))
  File "/Users/ingramm/Projects/software/jax/jax/interpreters/xla.py", line 105, in handle_result
    raise FloatingPointError("invalid value")
FloatingPointError: invalid value

I've put up a little reproducible example here: error.zip. I hope I included all the relevant files, but please let me know if there are any problems with it!

The error happens when I try to calculate the gradient of the log determinant of the negative Hessian. This seems to run fine in autograd. Would be really grateful for any ideas.

mattjj commented 5 years ago

Woohoo, the nans were caught just like we wanted! Now we've got this bug on the run.

Thanks for the repro. My guess is that this is coming from multiplying 0 * inf. Maybe there's a canonical value we'd choose here (like 0 if the covector being pulled back is 0, no matter what the Jacobian is). But that might be missing the larger issue, namely that things might be becoming ill-conditioned.
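
To illustrate how 0 * inf can appear only in the backward pass, here is a small sketch. The function f is a made-up toy, not taken from the repro: its forward value is finite, but the derivative of sqrt at 0 is infinite, and multiplying that by a zero cotangent gives nan.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Forward pass at x = 0: 0.0 * sqrt(0.0) = 0.0, which is finite.
    return 0.0 * jnp.sqrt(x)

value = f(0.0)

# Backward pass at x = 0: the cotangent reaching sqrt is 0, and the local
# derivative 0.5 / sqrt(0) is inf, so the product is 0 * inf = nan.
gradient = jax.grad(f)(0.0)

print(value, gradient)
```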

Is the matrix for which you're computing the log determinant becoming very ill-conditioned, or even indefinite? It could be that JAX's linalg has slightly different numerics than Autograd's, and maybe we could improve the stability of some of our Jacobian calculations.

mattjj commented 5 years ago

I relabeled the issue as a bug because now we're trying to figure out why JAX eventually produces NaNs here where Autograd might not. I suspect it's a question of numerical stability.

I can't look at your code right away, but I plan to get to it later.

martiningram commented 5 years ago

Thanks Matt! I haven't read as much about numerical linear algebra as I would like to yet, but it looks like the largest eigenvalue of the matrix is about 5.2E5 and the smallest is 1.76E0, which I guess means we have a condition number of about 3E5 (?). Does that sound problematic? There don't seem to be any obvious issues computing the Cholesky decomposition etc.; I don't run into errors there.
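
The arithmetic behind that estimate, as a quick check. It assumes the matrix is symmetric positive definite, so that the 2-norm condition number is the ratio of the largest to the smallest eigenvalue. The "digits lost" figure is the usual rule of thumb, not a bound specific to this problem.

```python
import math

lambda_max = 5.2e5   # largest eigenvalue reported above
lambda_min = 1.76    # smallest eigenvalue reported above

# For a symmetric positive definite matrix, cond_2 = lambda_max / lambda_min.
condition_number = lambda_max / lambda_min

# Rule of thumb: roughly log10(cond) decimal digits of accuracy are lost,
# out of about 16 available in float64.
digits_lost = math.log10(condition_number)

print(f"condition number: {condition_number:.3g}")
print(f"approx. digits lost: {digits_lost:.1f}")
```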

mattjj commented 5 years ago

That doesn't sound bad at all, no. Hmm...

mattjj commented 5 years ago

(By the way, totally coincidentally I'm flying to Melbourne a week from today.)

martiningram commented 5 years ago

Oh awesome, we should meet up if you have any time to spare!

martiningram commented 5 years ago

Hi there, just wondering if there might be an update on this? No rush at all. I've tried some other things like changing the determinant by rearranging the equations, as well as using different Cholesky decompositions to calculate the determinant, but have not had any luck so far. No problem though if there are more pressing things / there's no time to look at this right now!

jekbradbury commented 5 years ago

TensorFlow recently made a couple changes to perform all gradient multiplication (products of each Jacobian-transpose and seed) in ops where J-transpose could be infinity using a special multiplication op where 0 * inf is 0. I wonder if that might be the way to go here.
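
A rough sketch of what such a multiplication could look like. The name safe_mul_sketch is hypothetical; this is not the TensorFlow op or any JAX internal, only an illustration of the convention that a zero first argument wins over an infinite second one.

```python
import jax.numpy as jnp

def safe_mul_sketch(x, y):
    """Multiply x and y, but return 0 wherever x is exactly 0.

    Hypothetical helper for illustration only. The plain product may
    contain nan (from 0 * inf); jnp.where discards it where x == 0.
    """
    x = jnp.asarray(x)
    y = jnp.asarray(y)
    product = x * y
    return jnp.where(x == 0, jnp.zeros_like(product), product)

plain = jnp.asarray(0.0) * jnp.inf       # ordinary IEEE result
safe = safe_mul_sketch(0.0, jnp.inf)     # zero by convention
usual = safe_mul_sketch(2.0, 3.0)        # unaffected when x != 0
```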

mattjj commented 5 years ago

@jekbradbury That's what we do with lax._safe_mul, as in 58749c0. It could be that we need to use it in some more JVPs.

bantin commented 4 years ago

Hi @mattjj, I'm having the same issue, but it's a bit of a different use case. In my case, I think the NaNs are happening because there is an inf * 0 happening somewhere. I'd like to define that to be zero.

The context is that I am doing learning in an HMM. I wrote the forward pass to compute a log-normalizer, and I'm using grad to compute expectations. It's actually quite similar to this gist, but I have Gaussian observations. At certain points during EM, the value of the log-normalizer will compute just fine, but the gradients will have NaNs in them.

Here's my log-normalizer function:

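import jax.numpy as np  # assumed: np is jax.numpy, as was common at the time
from jax.experimental import loops  # assumed import for loops.Scope

def _log_normalizer(log_A, log_likelihoods):
  # Work in probability space; an entry of -inf becomes an exact 0 here.
  A, likes = map(np.exp, (log_A, log_likelihoods))
  N, K = likes.shape

  with loops.Scope() as s:
    s.alpha_p = np.ones(K)
    s.log_prob = 0.0
    for t in s.range(N):
      # Unnormalized forward message at time t.
      alpha_c = s.alpha_p * likes[t]

      # Accumulate the log of the per-step normalizer, then normalize.
      Zt = np.sum(alpha_c)
      s.log_prob += np.log(Zt)
      alpha_c /= Zt
      # Propagate through the transition matrix.
      s.alpha_p = alpha_c @ A

    return s.log_prob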

Do you have a recommendation here? Maybe if I masked out the elements of log_likelihoods that are -inf jax will know to ignore them in the gradient?
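
On the masking idea: masking only the output of an operation is generally not enough, because the masked-out branch is still differentiated and can contribute 0 * inf or 0 / 0. A common pattern is to mask the input as well (sometimes called the "double where" trick). The sketch below uses a made-up toy function, not the HMM code above, and is not a diagnosis of where that code produces nans.

```python
import jax
import jax.numpy as jnp

def naive_masked_log_sum(x):
    # Output-only mask: log(x) is still differentiated at x = 0.
    return jnp.sum(jnp.where(x > 0, jnp.log(x), 0.0))

def masked_log_sum(x):
    mask = x > 0
    # Input mask: replace excluded entries with a harmless dummy value
    # so the derivative of log stays finite there.
    safe_x = jnp.where(mask, x, 1.0)
    return jnp.sum(jnp.where(mask, jnp.log(safe_x), 0.0))

x = jnp.array([0.0, 2.0])
naive_grad = jax.grad(naive_masked_log_sum)(x)  # first entry is nan
safe_grad = jax.grad(masked_log_sum)(x)         # first entry is 0
print(naive_grad, safe_grad)
```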

ahmadsalim commented 4 years ago

FWIW, I am also experiencing NaNs in a similar program that uses loops.Scope() which happens during the backward pass on the while primitive. I have been trying to debug the state, but it seems a bit hard to understand what the buffers do. Any tips on how to handle this?

Thanks!

ahmadsalim commented 4 years ago

@mattjj Is there a way to debug NaNs in complex primitives like while-loops? As far as I understand the NaN checks happen somewhere outside the loop so the body is opaque. Would it make sense to allow compiling the while-loop as a series of xla_call to the body function (with checks) for debugging purposes in the JAX XLA interpreter? E.g., replacing

while_loop(cond_fn, body_fn, coll) with xla_call(body_fn, coll[0]) ... xla_call(body_fn, coll[n - 1]) etc.?
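
One way to approximate this idea without touching the interpreter is a plain Python loop with the same calling convention as while_loop, checking the state after every iteration. The helper checked_while_loop below is hypothetical and only covers the forward pass; it does not make the backward pass of the real while primitive any less opaque.

```python
import jax
import jax.numpy as jnp

def checked_while_loop(cond_fn, body_fn, init):
    """Debug stand-in for a while loop: runs op-by-op in Python and
    raises as soon as any leaf of the loop state is non-finite."""
    state = init
    step = 0
    while cond_fn(state):
        state = body_fn(state)
        for leaf in jax.tree_util.tree_leaves(state):
            if not bool(jnp.all(jnp.isfinite(leaf))):
                raise FloatingPointError(
                    f"non-finite value after iteration {step}")
        step += 1
    return state

# Toy loop: repeatedly take the log; the value eventually goes negative
# and the next log produces nan.
def cond_fn(state):
    i, _ = state
    return i < 10

def body_fn(state):
    i, x = state
    return i + 1, jnp.log(x)

try:
    checked_while_loop(cond_fn, body_fn, (0, jnp.array(3.0)))
    failed_at = None
except FloatingPointError as err:
    failed_at = str(err)
    print(failed_at)
```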

I am willing to help looking into the implementation myself, but I need a bit of guidance of what can be done.

Thanks in advance!

ahmadsalim commented 4 years ago

Never mind, I have written a high-level Jaxpr interpreter here (based on the documentation): https://github.com/aleatory-science/jaxinterp . I think that will help me debug :)

KristianHolsheimer commented 4 years ago

@mattjj I just wanted to say a huge thank you for this feature.

This just got me unstuck after banging my head against the wall for a week.

e-pet commented 11 months ago

Quick note for people (like me) looking for a more interactive debugging solution.

Enabling jax_debug_nans will throw an error message like invalid value (nan) encountered in jit(mul), pointing to a line in your code such as cond_true_val = 1/beta * jax.numpy.log(1 + jax.numpy.exp(beta * x)).

Figuring out which particular call in that line caused the issue and why, as well as the values of the involved variables (which of them is problematic, and why?), can be non-trivial. This is especially true if the error occurs in a late iteration of a loop, and simply jax.debug.printing every iteration is not an option.

In this case, what you probably want is described here: printing/breakpointing conditional on some expression being non-finite. Works for dynamic jax tracer arrays inside jitted functions, etc. (and without any modifications to the jax config).
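
A minimal sketch of that pattern, assuming a JAX version that provides jax.debug.print. The helper name print_if_nonfinite and the softplus function are illustrative; replacing the jax.debug.print call with jax.debug.breakpoint() should give the interactive variant.

```python
import jax
import jax.numpy as jnp

def print_if_nonfinite(label, x):
    # Only the "report" branch has a side effect; it runs at execution
    # time, and only when x contains inf or nan.
    def report(x):
        jax.debug.print(label + " is non-finite: {x}", x=x)

    def ok(x):
        pass

    jax.lax.cond(jnp.isfinite(x).all(), ok, report, x)

@jax.jit
def softplus(x):
    beta = 1.0
    z = jnp.exp(beta * x)              # overflows to inf for large x
    print_if_nonfinite("exp(beta * x)", z)
    return 1 / beta * jnp.log(1 + z)

result = softplus(jnp.array(1000.0))   # triggers the report branch
```

The check works on traced values inside jit because the condition is evaluated by lax.cond at run time instead of by Python at trace time.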

I just thought I'd leave this here since it took me a few hours to find this solution.

mattjj commented 11 months ago

@e-pet thanks!

Do you have any ideas for how we can make this runtime value debugging docs page more easily discoverable?

I think this issue is actually solved by that docs page and its subpages, so I'm going to close it, though let me know if you disagree. We can probably continue to make nan debugging better (we have some ideas), but I think that docs page covers the best tips (and we'll be sure to update it as we improve tools).

e-pet commented 11 months ago

@mattjj I fully agree, the docs pages are excellent and cover this topic very well!

Regarding discoverability:

Finally, and just in case it's helpful in any way for your efforts in further improving the debugging workflow, here's an issue that I haven't yet figured out how to debug effectively: I have the situation noted in my previous comment, i.e., enabling jax_debug_nans throws an error in iteration ~50k of some optimization loop after 30-60 minutes of runtime. I see the line where it occurs, but I would really need to look at the variables and step up/down the call stack to understand what's causing it. However, I cannot for the life of me figure out a way to make breakpoint_if_nonfinite trigger on this issue (it occurs deep in a complex operation), and so I am unable to get to an interactive debugger on this error. An option (or simple pattern) to make jax_debug_nans automatically open up the jax debugger at the position the error is raised would be amazing, but I have no idea whether that is technically feasible at all. (Total jax novice here, as is probably obvious.)