HIPS / autograd

Efficiently computes derivatives of NumPy code.
MIT License

Debugging issues in the reverse pass #359

Open j-towns opened 6 years ago

j-towns commented 6 years ago

I'm getting underflow during the computation of the reverse pass of an RNN gradient. I have numpy raise an exception when underflow is detected. However, it's very difficult (or maybe impossible) to tell from this exception's traceback where in the forward pass the nodes that caused the error were created.
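
For reference, the way I'm turning underflow into an exception is numpy's floating point error settings, something like this (a minimal sketch; adjust the settings to taste):

import numpy as np

# Raise FloatingPointError instead of silently flushing to zero on underflow.
np.seterr(under='raise')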

Ideally, I'd like to be able to start my debugger (or any interpreter) from the forward pass frame in which the node that caused the issue was created. That way I could look around and find out, for example, in which iteration of the RNN the node was created.

I've come up with a crude partial solution, which is to have nodes cache the frame in which they were created:

import inspect

from autograd.core import VJPNode  # assuming VJPNode lives in autograd.core

class VJPDebugNode(VJPNode):
    __slots__ = ['frame']
    def __init__(self, *args, **kwargs):
        # Cache the frame at construction time; its f_back chain recovers the
        # forward-pass call stack that created this node.
        self.frame = inspect.currentframe()
        super().__init__(*args, **kwargs)

Now when the exception is raised during the reverse pass, I can go into my post-mortem debugger, find the node where the error occurred and look at the frame in which it was created. I can get a traceback from that frame using traceback.print_stack and look at the frame's local variables. I thought that I might be able to use the frame argument of IPython.core.debugger.set_trace or pdb.Pdb.set_trace to get the kind of debugging/interpreter functionality that I really want, but the frame argument doesn't seem to work, and may not have been intended for this kind of use case.
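
Concretely, once I've located the offending node in the debugger, the frame inspection looks roughly like this (node is just whatever name the bad node ends up bound to in the debugger session):

import traceback

# Print the forward-pass call stack as it looked when the node was created.
traceback.print_stack(node.frame)

# Poke around in the creating frame's locals, e.g. to find the RNN iteration index.
print(node.frame.f_locals)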

This isn't the first time I've had an exception on the reverse pass and wanted to locate the corresponding place in the forward pass where it originated. Have others had this issue? Any thoughts on a solution?

captain-pool commented 2 years ago

Hey @j-towns is there any development on this? I'm facing a similar problem and looking for a solution. In my case the forward pass works fine, but I get NaN gradients (divide by zero) during the backward pass.

j-towns commented 2 years ago

Hi there, sorry for the slow reply. The short answer is that there hasn't been any concrete progress on this. The situation is slightly better in JAX, see here for a tutorial on how to debug backward pass NaNs in JAX.
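
For anyone who finds this later: the JAX facility I have in mind is (if I remember correctly) its NaN-checking mode, enabled with something like:

from jax import config

# When a NaN is produced, re-run the computation op-by-op and raise at the offending primitive.
config.update("jax_debug_nans", True)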

One thing that is a very common cause of backward pass NaNs is functions like numpy.where and numpy.select. If there are NaNs in any of the inputs to those functions, even if they are not selected and therefore ignored in the outputs, you will still get NaNs in the backward pass. This is a basic issue that exists in some form in all array-based autodiff libraries. There's a detailed description here. The easiest workaround is known as the 'double where trick', example here.
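
To make that concrete, here's a minimal sketch of the trick (my own toy example, not the linked one): the naive version still differentiates sqrt at the bad inputs even though where discards them, while the double where keeps the bad inputs out of the differentiated branch entirely.

import autograd.numpy as np
from autograd import grad

def naive(x):
    # sqrt is still differentiated at x <= 0, so its nan/inf leaks into the gradient.
    return np.sum(np.where(x > 0, np.sqrt(x), 0.))

def double_where(x):
    # 'Double where': swap the bad inputs for a harmless dummy value first.
    safe_x = np.where(x > 0, x, 1.)
    return np.sum(np.where(x > 0, np.sqrt(safe_x), 0.))

x = np.array([4., 0.])
print(grad(naive)(x))         # [0.25, nan]
print(grad(double_where)(x))  # [0.25, 0.]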

captain-pool commented 2 years ago

Thanks for your reply. I later found out it was because of the Frobenius norm of a zero vector (which apparently is a problem in all autodiff libraries). I solved it by adding a small epsilon to the term. However, having a permanent fix like https://github.com/pytorch/pytorch/pull/2775 would be very helpful. (I can send a PR if you think this would be a good addition.)
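
To make the workaround concrete, here's a sketch of what I did (the epsilon value is just illustrative):

import autograd.numpy as np
from autograd import grad

def frob_norm(x):
    # The gradient involves x / ||x||, which is 0/0 = nan at x = 0.
    return np.linalg.norm(x)

def frob_norm_eps(x, eps=1e-12):
    # A small epsilon inside the square root keeps the denominator away from zero.
    return np.sqrt(np.sum(x ** 2) + eps)

zero = np.zeros(3)
print(grad(frob_norm)(zero))      # [nan nan nan]
print(grad(frob_norm_eps)(zero))  # [0. 0. 0.]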

j-towns commented 2 years ago

I can send a PR if you think this will be a good addition

Yes, that would be a good addition. If you're willing to submit a PR I'll merge it. It seems we discussed this and agreed it was a good idea 5 years ago but never got round to implementing it.