Closed: martiningram closed this issue 11 months ago.
Great question. I don't think we have a solution in place for this right now, but I think we can make one.
There are at least two things to solve here:
np.seterr(invalid="raise")
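For context, NumPy's np.seterr(invalid="raise") turns invalid floating point operations such as 0 * inf into a FloatingPointError instead of a silent nan. Here is a minimal NumPy sketch of that behavior. The helper name multiply_raises is ours, for illustration only.

```python
import numpy as np

def multiply_raises(a, b):
    """Return True if a * b hits NumPy's 'invalid' floating point error."""
    old_settings = np.seterr(invalid="raise")  # e.g. 0 * inf now raises
    try:
        np.multiply(a, b)
        return False
    except FloatingPointError:
        return True
    finally:
        np.seterr(**old_settings)  # restore the previous error handling

print(multiply_raises(np.array([0.0]), np.array([np.inf])))  # True
print(multiply_raises(np.array([2.0]), np.array([3.0])))     # False
```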
Thanks Matt, great that you think it's worthwhile to enhance this!
Although it's cool you have a clear roadmap, I am actually really blocked by this at the moment and was wondering if there are any things I could do in the meantime? I'd be happy to dig into some of the backend if required. I've already changed to float64, which has helped but not resolved things.
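One detail worth checking when switching to float64: JAX computes in 32-bit by default, even when given float64 NumPy inputs, so 64-bit has to be enabled explicitly. A minimal sketch, assuming a JAX version where jax.config.update is available:

```python
import jax
import jax.numpy as jnp

# Enable 64-bit types; do this at startup, before creating any JAX arrays.
# Setting the environment variable JAX_ENABLE_X64=True has the same effect.
jax.config.update("jax_enable_x64", True)

x = jnp.array(1.0)
print(x.dtype)  # float64
```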
Thanks for letting us know. Any chance you can share a small repro? I just want to make sure we provide the right pointers or tools.
Blocking a user is the worst feeling! That's a magic word to get us to help you out ASAP :)
I added some basic nan debugging machinery in #482. As with other config options there are a few ways to turn it on:
- set the JAX_DEBUG_NANS environment variable to something truthy,
- add from jax.config import config and config.update("jax_debug_nans", True) near the top of your main file,
- add from jax.config import config and config.parse_flags_with_absl() to your main file, then set the option using a command-line flag like --jax_debug_nans=True.

Switching that option on adds a nan check to every floating point type value produced by XLA. That means values are pulled back to the host and checked as ndarrays for every primitive operation not under an @jit. For code under an @jit, the output of every @jit function is checked, and if a nan is present it will re-run the function in de-optimized op-by-op mode (effectively removing one level of @jit at a time).
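A minimal sketch of the check in action, covering both the op-by-op and the @jit case. This assumes a recent JAX where jax.config.update("jax_debug_nans", True) is the usual way to set the flag from Python (the from jax.config import config form above dates from when this thread was written). The helper names scale_by_inf and raises_nan_error are ours.

```python
import jax
import jax.numpy as jnp

# Same effect as JAX_DEBUG_NANS=True in the environment.
jax.config.update("jax_debug_nans", True)

def scale_by_inf(x):
    return x * jnp.inf  # 0 * inf produces nan

def raises_nan_error(fn, *args):
    """Return True if calling fn(*args) trips the jax_debug_nans check."""
    try:
        fn(*args)
    except FloatingPointError as err:
        print("caught:", err)
        return True
    return False

zero, two = jnp.array(0.0), jnp.array(2.0)
print(raises_nan_error(scale_by_inf, zero))           # op-by-op: True
print(raises_nan_error(jax.jit(scale_by_inf), zero))  # under jit: True
print(raises_nan_error(scale_by_inf, two))            # inf, not nan: False
```

Note that the flag only looks for nans; an inf result such as 2 * inf passes through unchecked.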
There could be tricky situations that arise, like nans that only occur under a @jit but don't get produced in de-optimized mode. In that case you'll see a warning message print out but your code will continue to execute, so we can dig in deeper.
If the nans are being produced in the backward pass of a gradient evaluation, then when an exception is raised, several frames up in the stack trace you'll be in the backward_pass function, which is essentially a simple jaxpr interpreter that walks the sequence of primitive operations in reverse. If it's not immediately obvious, you can poke around a bit to find the primitive that's producing the nan by doing things in an interactive debugger like p eqn.primitive in that stack frame.
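A well-known way to get nans only in the backward pass is a jnp.where that guards a log. The forward value is fine, but the gradient is nan. This minimal sketch (the function name log_or_zero is ours) shows jax_debug_nans turning that silent nan gradient into a FloatingPointError raised while the gradient is computed:

```python
import jax
import jax.numpy as jnp

def log_or_zero(x):
    # The forward value is fine: where() returns 0.0 at x == 0. In the backward
    # pass, though, the unused log branch still combines a zero cotangent with
    # log's infinite derivative at 0, which gives nan.
    return jnp.where(x > 0.0, jnp.log(x), 0.0)

value = log_or_zero(0.0)                   # 0.0
grad_at_zero = jax.grad(log_or_zero)(0.0)  # nan, with no error by default
print(value, grad_at_zero)

jax.config.update("jax_debug_nans", True)
caught = False
try:
    jax.grad(log_or_zero)(0.0)  # now raises while computing the gradient
except FloatingPointError as err:
    caught = True
    print("caught during grad:", err)
```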
How does that sound? This is a good opportunity to add exactly the tooling we want; JAX is tiny and easy to instrument, so there's no reason not to get this right.
Oh wow, thanks so much! Can't wait to try this out. So sorry, I didn't mean to make you feel bad! It's very likely this is some stupid mistake on my end, but I am very grateful that you added tooling so quickly!
Hah don't worry, I was joking about feeling bad. But it does light a fire under us whenever someone is blocked :)
Check out this comment for a toy example of how to use this. It's going to be a bit hairier to debug nans in the backward pass, but hopefully not too bad.
When you get back to it, let us know how it goes, and any additional issues you run into. This kind of feedback is incredibly helpful, and it's going to pay off a lot in the future when it helps us build an awesome debugging experience.
Hi Matt,
Thanks for this! Indeed it now crashes rather than returning nan, which is great. From the stack trace below, it looks like the mul operation raises the issue:
/Users/ingramm/Projects/software/jax/jax/lib/xla_bridge.py:128: UserWarning: No GPU found, falling back to CPU.
warnings.warn('No GPU found, falling back to CPU.')
/Users/ingramm/Projects/software/jax/jax/numpy/linalg.py:51: UserWarning: numpy.linalg support is experimental and may cause silent failures or wrong outputs
warnings.warn(_EXPERIMENTAL_WARNING)
Log posterior is Traced<ConcreteArray(-7522.515563681398)>with<JVPTrace(level=1/0)>.
Log determinant is Traced<ConcreteArray(247.07115427818476)>with<JVPTrace(level=1/0)>
Traceback (most recent call last):
File "nan_gradient.py", line 179, in <module>
data['l'], data['b'], data['n_c']))
File "/Users/ingramm/Projects/software/jax/jax/api.py", line 206, in grad_f
ans, g = value_and_grad_f(*args, **kwargs)
File "/Users/ingramm/Projects/software/jax/jax/api.py", line 243, in value_and_grad_f
g = vjp_py(onp.ones((), onp.result_type(ans)))
File "/Users/ingramm/Projects/software/jax/jax/api_util.py", line 56, in apply_jaxtree_fun
ans = fun(*args)
File "/Users/ingramm/Projects/software/jax/jax/api.py", line 570, in out_vjp_packed
return out_vjp(cotangent_in)
File "/Users/ingramm/Projects/software/jax/jax/interpreters/ad.py", line 81, in vjp_
_, arg_cts = backward_pass(jaxpr, consts, (), dummy_args, dummy_primal_and_ct)
File "/Users/ingramm/Projects/software/jax/jax/interpreters/ad.py", line 139, in backward_pass
cts_out = get_primitive_transpose(eqn.primitive)(ct_in, *invals, **eqn.params)
File "/Users/ingramm/Projects/software/jax/jax/interpreters/ad.py", line 338, in bilinear_transpose
out = zero if cotangent is zero else lhs_rule(cotangent, y, **kwargs)
File "/Users/ingramm/Projects/software/jax/jax/lax.py", line 242, in mul
return mul_p.bind(x, y)
File "/Users/ingramm/Projects/software/jax/jax/core.py", line 75, in bind
return self.impl(*args, **kwargs)
File "/Users/ingramm/Projects/software/jax/jax/interpreters/xla.py", line 54, in apply_primitive
return compiled_fun(*args)
File "/Users/ingramm/Projects/software/jax/jax/interpreters/xla.py", line 85, in execute_compiled_primitive
return result_handler(compiled.Execute(input_bufs, not core.skip_checks))
File "/Users/ingramm/Projects/software/jax/jax/interpreters/xla.py", line 105, in handle_result
raise FloatingPointError("invalid value")
FloatingPointError: invalid value
I've put up a little reproducible example here: error.zip. I hope I included all the relevant files, but please let me know if there are any problems with it!
The error happens when I try to calculate the gradient of the log determinant of the negative Hessian. This seems to run fine in autograd. Would be really grateful for any ideas.
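For readers following along without the attachment, here is a minimal sketch of the kind of computation being described: the gradient of the log determinant of the negative Hessian of a log posterior. The toy log_posterior is a hypothetical stand-in chosen so the answer is known analytically. Its negative Hessian is diag(exp(theta)), so the log determinant is sum(theta) and its gradient is all ones. Both slogdet and a Cholesky-based log determinant are shown.

```python
import jax
import jax.numpy as jnp

def log_posterior(theta):
    # Hypothetical toy log posterior whose negative Hessian is diag(exp(theta)).
    return jnp.sum(theta) - jnp.sum(jnp.exp(theta))

def logdet_neg_hessian(theta):
    neg_hess = -jax.hessian(log_posterior)(theta)
    _, logabsdet = jnp.linalg.slogdet(neg_hess)
    return logabsdet

def logdet_neg_hessian_chol(theta):
    # Same quantity via Cholesky: log det A = 2 * sum(log(diag(L))).
    chol = jnp.linalg.cholesky(-jax.hessian(log_posterior)(theta))
    return 2.0 * jnp.sum(jnp.log(jnp.diag(chol)))

theta = jnp.array([0.1, -0.3, 0.5])
print(logdet_neg_hessian(theta))                 # analytically sum(theta) = 0.3
print(jax.grad(logdet_neg_hessian)(theta))       # analytically all ones
print(jax.grad(logdet_neg_hessian_chol)(theta))  # should agree
```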
Woohoo, the nans were caught just like we wanted! Now we've got this bug on the run.
Thanks for the repro. My guess is that this is coming from multiplying 0 * inf. Maybe there's a canonical value we'd choose here (like 0 if the covector being pulled back is 0, no matter what the Jacobian is). But that might be missing the larger issue, namely that things might be becoming ill-conditioned.
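To make the "canonical value" idea concrete, here is a small sketch of the convention "a zero cotangent pulls back to zero, whatever the Jacobian is". The helper pullback_mul is hypothetical and only illustrates the values involved; it is not JAX's internal implementation.

```python
import jax.numpy as jnp

def pullback_mul(cotangent, jacobian):
    # Hypothetical helper: the plain product gives 0 * inf = nan, so entries
    # with a zero cotangent are overwritten with 0.
    product = cotangent * jacobian
    return jnp.where(cotangent == 0, jnp.zeros_like(product), product)

ct = jnp.array([0.0, 2.0])
jac = jnp.array([jnp.inf, 3.0])
print(ct * jac)               # first entry is nan
print(pullback_mul(ct, jac))  # first entry is 0, second is 6
```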
Is the matrix for which you're computing the log determinant becoming very ill-conditioned, or even indefinite? It could be that JAX's linalg has slightly different numerics than Autograd's, and maybe we could improve the stability of some of our Jacobian calculations.
I relabeled the issue as a bug because now we're trying to figure out why JAX eventually produces NaNs here where Autograd might not. I suspect it's a question of numerical stability.
I can't look at your code right away, but I plan to get to it later.
Thanks Matt! I haven't read as much about numerical linear algebra as I would like to yet, but it looks like the largest eigenvalue of the matrix is about 5.2E5 and the smallest is 1.76E0, which I guess means we have a condition number of about 3E5 (?). Does that sound problematic? There don't seem to be any obvious issues computing the cholesky etc., I don't run into errors there.
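A quick way to double-check that estimate: for a symmetric positive definite matrix, the 2-norm condition number is the ratio of the largest to the smallest eigenvalue, so 5.2e5 / 1.76 ≈ 2.95e5. By the usual rule of thumb you lose about log10(kappa) ≈ 5.5 of float64's roughly 16 significant digits, which is consistent with this not being alarming on its own. The sketch uses a diagonal stand-in with the quoted eigenvalues, not the actual matrix from the repro.

```python
import jax.numpy as jnp

def spd_condition_number(a):
    # For a symmetric positive definite matrix, the 2-norm condition number
    # is largest eigenvalue / smallest eigenvalue.
    eigs = jnp.linalg.eigvalsh(a)  # returned in ascending order
    return eigs[-1] / eigs[0]

a = jnp.diag(jnp.array([5.2e5, 1.76]))
kappa = spd_condition_number(a)
print(kappa)               # about 2.95e5
print(jnp.log10(kappa))    # about 5.5 decimal digits at risk
print(jnp.linalg.cond(a))  # same value via the SVD
```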
That doesn't sound bad at all, no. Hmm...
(By the way, totally coincidentally I'm flying to Melbourne a week from today.)
Oh awesome, we should meet up if you have any time to spare!
Hi there, just wondering if there might be an update on this? No rush at all. I've tried some other things like changing the determinant by rearranging the equations, as well as using different Cholesky decompositions to calculate the determinant, but have not had any luck so far. No problem though if there are more pressing things / there's no time to look at this right now!
@jekbradbury That's what we do with lax._safe_mul, as in 58749c0. It could be that we need to use it in some more JVPs.
Hi @mattjj, I'm having the same issue, but it's a bit of a different use case. In my case, I think the NaNs are happening because there is an inf * 0 happening somewhere. I'd like to define that to be zero.
The context is that I am doing learning in an HMM. I wrote the forward pass to compute a log-normalizer, and I'm using grad to compute expectations. It's actually quite similar to this gist, but I have Gaussian observations. At certain points during EM, the value of the log-normalizer will compute just fine, but the gradients will have NaNs in them.
Here's my log-normalizer function:
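```python
import jax.numpy as np              # JAX's NumPy, imported as np in this snippet
from jax.experimental import loops  # experimental loops API in JAX at the time

def _log_normalizer(log_A, log_likelihoods):
    A, likes = map(np.exp, (log_A, log_likelihoods))
    N, K = likes.shape
    with loops.Scope() as s:
        s.alpha_p = np.ones(K)  # unnormalized forward message
        s.log_prob = 0.0
        for t in s.range(N):
            alpha_c = s.alpha_p * likes[t]
            Zt = np.sum(alpha_c)
            s.log_prob += np.log(Zt)  # accumulate the log normalizer
            alpha_c /= Zt             # normalize the message
            s.alpha_p = alpha_c @ A   # propagate through the transitions
        return s.log_prob
```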
Do you have a recommendation here? Maybe if I masked out the elements of log_likelihoods that are -inf, jax will know to ignore them in the gradient?
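One common way to keep exp(-inf) = 0 out of products like these is to run the same forward recursion entirely in log space. The sketch below swaps in jax.lax.scan and jax.scipy.special.logsumexp for loops.Scope and explicit normalization. It is an alternative technique, not the code from the gist, and it still produces nan at a time step where every state has log-likelihood -inf. The brute-force check enumerates all state paths and is only for tiny examples. The input values are made up.

```python
import itertools

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def log_normalizer(log_A, log_likelihoods):
    """Log-space forward recursion (sketch). log_A: (K, K), log_likelihoods: (N, K)."""
    log_A = jnp.asarray(log_A)
    log_likelihoods = jnp.asarray(log_likelihoods)
    K = log_likelihoods.shape[1]

    def step(log_alpha_p, log_lik_t):
        log_alpha_c = log_alpha_p + log_lik_t
        log_Zt = logsumexp(log_alpha_c)     # log(sum(alpha_c))
        log_alpha_c = log_alpha_c - log_Zt  # normalize the message
        # log(alpha_c @ A): logsumexp over the previous state i of alpha_c[i] * A[i, j]
        log_alpha_next = logsumexp(log_alpha_c[:, None] + log_A, axis=0)
        return log_alpha_next, log_Zt

    # Like the original, start from alpha = ones(K), i.e. log alpha = zeros(K).
    init = jnp.zeros(K, dtype=log_likelihoods.dtype)
    _, log_Zs = jax.lax.scan(step, init, log_likelihoods)
    return jnp.sum(log_Zs)

def brute_force_log_normalizer(log_A, log_likelihoods):
    # Sums over every state path; only feasible for tiny N and K.
    N, K = log_likelihoods.shape
    scores = []
    for path in itertools.product(range(K), repeat=N):
        s = log_likelihoods[0, path[0]]
        for t in range(1, N):
            s = s + log_A[path[t - 1], path[t]] + log_likelihoods[t, path[t]]
        scores.append(s)
    return logsumexp(jnp.array(scores))

log_A = jnp.log(jnp.array([[0.9, 0.1], [0.2, 0.8]]))
log_lik = jnp.array([[-1.0, -jnp.inf], [-0.5, -2.0], [-jnp.inf, -0.3]])
print(log_normalizer(log_A, log_lik), brute_force_log_normalizer(log_A, log_lik))
grads = jax.grad(log_normalizer, argnums=(0, 1))(log_A, log_lik)
print(all(bool(jnp.all(jnp.isfinite(g))) for g in grads))
```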
FWIW, I am also experiencing NaNs in a similar program that uses loops.Scope(), which happens during the backward pass on the while primitive. I have been trying to debug the state, but it seems a bit hard to understand what the buffers do. Any tips on how to handle this?
Thanks!
@mattjj Is there a way to debug NaNs in complex primitives like while-loops? As far as I understand, the NaN checks happen somewhere outside the loop, so the body is opaque. Would it make sense to allow compiling the while-loop as a series of xla_call to the body function (with checks) for debugging purposes in the JAX XLA interpreter? E.g., replacing
while_loop(cond_fn, body_fn, coll)
with
xla_call(body_fn, coll[0]) ... xla_call(body_fn, coll[n - 1])
etc.?
I am willing to help look into the implementation myself, but I need a bit of guidance on what can be done.
Thanks in advance!
Never mind, I have written a high-level Jaxpr interpreter here (based on the documentation): https://github.com/aleatory-science/jaxinterp . I think that will help me debug :)
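Another option for looking inside a while_loop body, assuming a reasonably recent JAX: under jax.disable_jit(), lax.while_loop runs as an ordinary Python loop, so jax_debug_nans checks each iteration's operations individually and the traceback should point into the body function. A minimal sketch with a made-up body that eventually takes the log of a negative number:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", True)

def cond_fn(carry):
    i, _ = carry
    return i < 10

def body_fn(carry):
    i, x = carry
    # log(log(log(...))) eventually takes the log of a negative number -> nan
    return i + 1, jnp.log(x)

def run(x0):
    return jax.lax.while_loop(cond_fn, body_fn, (0, x0))

caught = False
with jax.disable_jit():  # run the loop op-by-op so the body is not opaque
    try:
        run(jnp.array(100.0))
    except FloatingPointError as err:
        caught = True
        print("caught inside the loop body:", err)
```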
@mattjj I just wanted to say a huge thank you for this feature.
This just got me unstuck after banging my head against the wall for a week.
Quick note for people (like me) looking for a more interactive debugging solution.
Enabling jax_debug_nans will throw an error message like invalid value (nan) encountered in jit(mul), pointing to a line in your code such as cond_true_val = 1/beta * jax.numpy.log(1 + jax.numpy.exp(beta * x)).
Figuring out which particular call in that line caused the issue and why, as well as the values of the involved variables (which of them is problematic, and why?), can be non-trivial. This is especially true if the error occurs in a late iteration of a loop, and simply jax.debug.print-ing every iteration is not an option.
In this case, what you probably want is described here: printing/breakpointing conditional on some expression being non-finite. Works for dynamic jax tracer arrays inside jitted functions, etc. (and without any modifications to the jax config).
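The docs' breakpoint_if_nonfinite pattern wraps jax.debug.breakpoint in a lax.cond so that it only fires on non-finite values. The sketch below follows the same idea but swaps in jax.debug.callback to record the offending step instead of opening a debugger, so it runs non-interactively. Putting jax.debug.breakpoint() in _report gives the interactive version. The helper names and the loop are made up for illustration.

```python
import jax
import jax.numpy as jnp

nonfinite_events = []  # filled on the host by the callback below

def _record(step, x):
    nonfinite_events.append((int(step), float(x)))

def _skip(step, x):
    return None

def _report(step, x):
    jax.debug.callback(_record, step, x)

def report_if_nonfinite(step, x):
    # Only call back to the host when x contains a nan or inf.
    jax.lax.cond(jnp.all(jnp.isfinite(x)), _skip, _report, step, x)

@jax.jit
def iterate_log(x0):
    def body(i, x):
        x = jnp.log(x)
        report_if_nonfinite(i, x)
        return x
    return jax.lax.fori_loop(0, 6, body, x0)

iterate_log(jnp.array(100.0))
jax.effects_barrier()  # wait for pending callbacks
print(sorted(nonfinite_events))  # steps 4 and 5 are non-finite
```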
I just thought I'd leave this here since it took me a few hours to find this solution.
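As an illustration of why a line like the one quoted above can be hard to pin down: with a large beta * x, jnp.exp overflows, so the forward value is merely inf while the gradient is nan (0 * inf inside the exp rule). The logaddexp rewrite is our suggestion, not part of the original comment, and the values of beta and x are hypothetical.

```python
import jax
import jax.numpy as jnp

beta = 1.0

def naive(x):
    # The expression quoted above: exp overflows to inf for large beta * x.
    return 1 / beta * jnp.log(1 + jnp.exp(beta * x))

def stable(x):
    # Same function written with logaddexp: log(exp(0) + exp(beta * x)) / beta.
    return jnp.logaddexp(0.0, beta * x) / beta

x = 1000.0
print(naive(x), jax.grad(naive)(x))    # inf and nan
print(stable(x), jax.grad(stable)(x))  # 1000.0 and 1.0
```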
@e-pet thanks!
Do you have any ideas for how we can make this runtime value debugging docs page more easily discoverable?
I think this issue is actually solved by that docs page and its subpages, so I'm going to close it, though let me know if you disagree. We can probably continue to make nan debugging better (we have some ideas), but I think that docs page covers the best tips (and we'll be sure to update it as we improve tools).
@mattjj I fully agree, the docs pages are excellent and cover this topic very well!
Regarding discoverability:
Finally, and just in case it's helpful in any way for your efforts in further improving the debugging workflow, here's an issue that I haven't yet figured out how to debug effectively: I have the situation noted in my previous comment, i.e., enabling jax_debug_nans throws an error in iteration ~50k of some optimization loop after 30-60 minutes of runtime. I see the line where it occurs, but I would really need to look at the variables and step up/down the call stack to understand what's causing it. However, I cannot for the life of me figure out a way to make breakpoint_if_nonfinite trigger on this issue (it occurs deep in a complex operation), so I am unable to get to an interactive debugger on this error. An option (or simple pattern) to make jax_debug_nans automatically open the JAX debugger at the position where the error is raised would be amazing, but I have no idea whether that is technically feasible at all. (Total JAX novice here, as is probably obvious.)
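One simple pattern that may help here: jax_debug_nans surfaces as an ordinary Python FloatingPointError, so you can catch it and open pdb.post_mortem on its traceback. Since JAX re-runs a jitted function op-by-op after detecting a nan, our understanding is that the user frames you land in should hold concrete arrays rather than tracers. JAX hides its internal frames by default; setting JAX_TRACEBACK_FILTERING=off shows them. The helper call_with_post_mortem and the function step are ours.

```python
import pdb
import sys

import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", True)

def call_with_post_mortem(fn, *args, debugger=pdb.post_mortem):
    """Call fn(*args); if jax_debug_nans raises, hand the traceback to a debugger."""
    try:
        return fn(*args)
    except FloatingPointError:
        debugger(sys.exc_info()[2])  # default: interactive pdb at the raise site
        raise

@jax.jit
def step(x):
    return jnp.log(x) * 2.0

# Non-interactive demo: record the traceback instead of opening pdb.
captured = []
try:
    call_with_post_mortem(step, jnp.array(-1.0), debugger=captured.append)
except FloatingPointError:
    pass
print(len(captured))  # 1
```

Interactively, call_with_post_mortem(step, x) drops into pdb, where up and down walk the stack. Running the whole script under python -m pdb gives a similar post-mortem session on any uncaught exception.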
Hi there,
I am running an optimisation using gradients from Jax, and everything goes well for a number of steps until the gradients returned are all nan. I am having a bit of a hard time tracking down where the problem is; the forward calculations all seem to be fine. Is there some way I can work out which operation is causing the nans from grad? This would be really useful. Thanks!