j-towns opened this issue 6 years ago (status: open)
Hey @j-towns, has there been any progress on this? I'm facing a similar problem and looking for a solution. In my case the forward pass works great, but I get NaN gradients (from a divide by zero) during the backward pass.
Hi there, sorry for the slow reply. The short answer is that there hasn't been any concrete progress on this. The situation is slightly better in JAX; see here for a tutorial on how to debug backward-pass NaNs in JAX.
One very common cause of backward-pass NaNs is functions like `numpy.where` and `numpy.select`. Suppose any of the inputs to those functions contain NaNs, even at positions that are not selected and therefore don't appear in the output. You will still get NaNs in the backward pass. The gradient flowing into the unselected positions is zero, but the chain rule then multiplies that zero by the derivative of whatever produced those inputs, and 0 * NaN (or 0 * inf) is NaN. This is a basic issue that exists in some form in all array-based autodiff libraries. There's a detailed description here. The easiest workaround is known as the 'double where trick', example here.
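The effect can be reproduced without any autodiff library. The sketch below writes out by hand the chain-rule steps a reverse-mode library would perform for `f(x) = where(x > 0, sqrt(x), 0)`. It is a minimal sketch under that assumption. The helper names are hypothetical, and real libraries organise their VJPs differently, but the arithmetic is the same. The naive gradient multiplies a zero by an infinite or NaN derivative. The double-where version instead feeds `sqrt` a harmless value at the unselected positions, so nothing non-finite is ever produced.

```python
import numpy as np

def f_naive(x):
    # Forward pass is fine: sqrt's NaN at x < 0 is discarded by where.
    with np.errstate(invalid="ignore"):
        return np.where(x > 0, np.sqrt(x), 0.0)

def grad_naive(x):
    # Reverse pass, step 1: where's VJP routes the upstream gradient (1.0)
    # to the sqrt branch only where the condition holds, and 0.0 elsewhere.
    g_sqrt = np.where(x > 0, 1.0, 0.0)
    # Step 2: sqrt's VJP multiplies by 0.5 / sqrt(x) at *every* position,
    # so unselected positions become 0 * inf or 0 * nan, i.e. nan.
    with np.errstate(divide="ignore", invalid="ignore"):
        return g_sqrt * (0.5 / np.sqrt(x))

def f_double_where(x):
    # Inner where: replace unsafe inputs with a harmless value (1.0) first.
    safe_x = np.where(x > 0, x, 1.0)
    return np.where(x > 0, np.sqrt(safe_x), 0.0)

def grad_double_where(x):
    safe_x = np.where(x > 0, x, 1.0)
    g_sqrt = np.where(x > 0, 1.0, 0.0)           # VJP of the outer where
    g_safe_x = g_sqrt * (0.5 / np.sqrt(safe_x))  # finite everywhere now
    return np.where(x > 0, g_safe_x, 0.0)        # VJP of the inner where

x = np.array([-1.0, 0.0, 4.0])
print(grad_naive(x))         # values: nan, nan, 0.25
print(grad_double_where(x))  # values: 0, 0, 0.25
```

Both versions compute the same forward values. Only the second keeps the unselected positions out of `sqrt`'s derivative.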
Thanks for your reply. I later found out it was caused by taking the Frobenius norm of a zero vector (which is apparently a problem in all autodiff libraries). I solved it by adding a small epsilon to the term. However, having a permanent fix like this https://github.com/pytorch/pytorch/pull/2775 would be very helpful (I can send a PR if you think this will be a good addition).
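The zero-vector case can be seen directly. The gradient of the Frobenius norm ||x|| = sqrt(sum(x**2)) is x / ||x||, which is 0 / 0 at the origin. The following minimal NumPy sketch uses hand-written gradients with hypothetical helper names (not autograd's API). It compares three options:

- the naive gradient;
- the epsilon workaround described above;
- one possible "permanent" fix that returns the subgradient 0 at the origin.

The thread doesn't say whether the linked PyTorch PR does exactly this, so treat the last function as an assumption.

```python
import numpy as np

def norm_grad_naive(x):
    # d||x||/dx = x / ||x||: 0 / 0 = nan when x is all zeros
    with np.errstate(invalid="ignore"):
        return x / np.sqrt(np.sum(x ** 2))

def norm_eps(x, eps=1e-12):
    # Epsilon workaround: the value is never exactly zero, so neither is the
    # gradient's denominator (but norm_eps(0) is sqrt(eps), not 0).
    return np.sqrt(np.sum(x ** 2) + eps)

def norm_eps_grad(x, eps=1e-12):
    return x / np.sqrt(np.sum(x ** 2) + eps)

def norm_grad_safe(x):
    # Hypothetical permanent fix: keep the exact forward value and define the
    # gradient at the origin as the subgradient 0.
    n = np.sqrt(np.sum(x ** 2))
    return np.zeros_like(x) if n == 0 else x / n

z = np.zeros((2, 2))
print(norm_grad_naive(z))  # all nan
print(norm_grad_safe(z))   # all zeros
```

The epsilon version is simpler but slightly changes the forward value near zero. The subgradient version keeps the forward value exact at the cost of a special case.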
> I can send a PR if you think this will be a good addition
Yes, that would be a good addition. If you're willing to submit a PR, I'll merge it. It seems we discussed this and agreed it was a good idea 5 years ago but never got round to implementing it.
I'm getting underflow while computing the reverse pass of an RNN gradient. I have NumPy raise an exception when underflow is detected. However, it's very difficult (or maybe impossible) to tell from this exception's traceback where in the forward pass the nodes that caused the error were created.
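For reference, one standard way to make NumPy raise on underflow is `np.seterr(under="raise")`. The `np.errstate` context manager does the same for a limited scope. The post doesn't say exactly how this was configured, so this is an assumption. Here `grad_step` is a hypothetical stand-in for one reverse-pass operation. The resulting `FloatingPointError` is raised at the multiplication inside `grad_step`, in the reverse pass. So its traceback says nothing about where the operands were created in the forward pass.

```python
import numpy as np

def grad_step(g, w):
    # Hypothetical stand-in for one VJP in the reverse pass
    return g * w

old_settings = np.seterr(under="raise")  # NumPy's default is under="ignore"
try:
    grad_step(np.array([1e-200]), np.array([1e-200]))  # 1e-400 underflows
    message = "no error"
except FloatingPointError as err:
    message = str(err)  # mentions "underflow"
finally:
    np.seterr(**old_settings)  # restore the previous global settings
```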
Ideally, I'd like to be able to start my debugger (or any interpreter) in the forward-pass frame in which the offending node was created. That way I could look around and find out, for example, in which iteration of the RNN the node was created.
I've come up with a crude partial solution, which is to have nodes cache the `frame` in which they were created.

Now when the exception is raised during the reverse pass, I can go into my post-mortem debugger, find the node where the error occurred and look at the frame in which it was created. I can get a traceback from that frame using `traceback.print_stack` and look at the frame's local variables. I thought I might be able to use the `frame` argument of `IPython.core.debugger.set_trace` or `pdb.Pdb.set_trace` to get the kind of debugging/interpreter functionality that I really want. However, the `frame` argument doesn't seem to work, and may not have been intended for this kind of use case.

This isn't the first time I've had an exception on the reverse pass and wanted to locate the corresponding place in the forward pass where it originated. Have others had this issue? Any thoughts on a solution?
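Below is a minimal sketch of the frame-caching idea. It is built on a toy reverse-mode engine whose names are all hypothetical (`Node`, `leaf`, `mul`, `backward`, `BackwardError`, `inspect_origin`, `rnn_forward`). It is not autograd's real node class, nor the code from the post.

Two details are worth noting. First, a cached frame is live. Once the loop has moved on, the frame's `f_locals` show the final values (here, the last `t`), not the values at creation time. The sketch therefore also snapshots a copy of the locals and the call stack (via `traceback.extract_stack`) when each node is created. Second, instead of relying on the debugger's `frame` argument, `inspect_origin` opens a plain `code.interact` prompt whose namespace is that snapshot. This gives interpreter access, though not stepping, at the forward-pass location. `backward` also wraps the floating-point error in an exception that carries the failing node, so there is no need to dig it out in the post-mortem debugger. Snapshotting every node's locals is expensive and keeps those objects alive, so this only makes sense as a debugging mode.

```python
import code
import sys
import traceback

import numpy as np


class Node:
    """Hypothetical reverse-mode node (not autograd's real class) that
    remembers where in the forward pass it was created."""

    def __init__(self, value, parents, vjp, frame):
        self.value = value
        self.parents = parents  # upstream Nodes
        self.vjp = vjp          # output gradient -> tuple of parent gradients
        self.frame = frame      # live frame of the forward-pass code
        # The frame keeps executing after this node is built, so snapshot
        # its locals and call stack as they are right now.
        self.created_locals = dict(frame.f_locals)
        self.created_stack = traceback.extract_stack(frame)


def leaf(value):
    return Node(np.array(value, dtype=float, ndmin=1), (), None,
                sys._getframe(1))


def mul(a, b):
    # sys._getframe(1) is mul's caller, i.e. the user's forward-pass code
    return Node(a.value * b.value, (a, b),
                lambda g: (g * b.value, g * a.value), sys._getframe(1))


class BackwardError(Exception):
    """Carries the node whose VJP hit a floating-point error."""

    def __init__(self, node):
        super().__init__("floating-point error in backward pass")
        self.node = node


def backward(output):
    order, seen = [], set()

    def visit(n):  # depth-first topological sort
        if id(n) not in seen:
            seen.add(id(n))
            for p in n.parents:
                visit(p)
            order.append(n)

    visit(output)
    grads = {id(output): np.ones_like(output.value)}
    for n in reversed(order):
        if n.vjp is None:  # leaves have nothing to propagate
            continue
        try:
            with np.errstate(under="raise"):
                parent_grads = n.vjp(grads[id(n)])
        except FloatingPointError as err:
            raise BackwardError(n) from err
        for p, pg in zip(n.parents, parent_grads):
            grads[id(p)] = grads.get(id(p), 0.0) + pg
    return grads


def inspect_origin(node):
    """Print where `node` was created, then open an interactive prompt
    whose namespace is the forward-pass frame at creation time."""
    print("".join(node.created_stack.format()))
    code.interact(local={**node.frame.f_globals, **node.created_locals})


def rnn_forward(h0, xs):
    h = leaf(h0)
    for t, x in enumerate(xs):
        h = mul(h, leaf(x))  # nodes built here record the current t
    return h


try:
    # Forward values stay finite, but the gradient w.r.t. h0 is 1e-400.
    backward(rnn_forward(1e300, [1e-200, 1e-200]))
except BackwardError as err:
    print("created at t =", err.node.created_locals["t"])  # 0
    print("frame now shows t =", err.node.frame.f_locals["t"])  # 1
    # inspect_origin(err.node)  # uncomment for an interactive prompt
```

Comparing the two printed values shows why caching the frame alone isn't enough to answer "which iteration?". The snapshot is what preserves the creation-time state.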