szagoruyko / pytorchviz

A small package to create visualizations of PyTorch execution graphs

Incorrect graph for saved variable using custom autograd function #60

Closed Varal7 closed 3 years ago

Varal7 commented 3 years ago

In the following example, the saved variable for each FixedGradientFunctionBackward node should be different, but they are merged into a single dot node.

https://colab.research.google.com/drive/1MyOV58n6oex9X_Z5HSRf-Gb87dclYNKN?usp=sharing

import torch
from torch.autograd import Function
from torchviz import make_dot

class FixedGradientFunction(Function):
    @staticmethod
    def forward(ctx, x, grad_x):
        ctx.save_for_backward(grad_x)
        return x

    @staticmethod
    def backward(ctx, grad_x):
        saved_grad_x, = ctx.saved_tensors
        return saved_grad_x, None

fn = FixedGradientFunction

x = torch.randn(1, requires_grad=True)

dense_1 = torch.rand(1)
dense_2 = torch.rand(1)

z = (fn.apply(x, dense_1) + fn.apply(x, dense_2)).sum()
# z.backward()
make_dot(z)
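
For reference, the two backward nodes really do hold different tensors. Continuing from the snippet above, a quick check by walking the graph (a rough sketch, assuming the saved_tensors attribute is accessible on the custom-Function nodes):

add_node = z.grad_fn.next_functions[0][0]   # AddBackward0
fixed_1 = add_node.next_functions[0][0]     # FixedGradientFunctionBackward
fixed_2 = add_node.next_functions[1][0]     # FixedGradientFunctionBackward

# Each node saved a different tensor, so the graph should show two
# separate saved-variable dots rather than one merged node.
print(torch.equal(fixed_1.saved_tensors[0], dense_1))  # expected: True
print(torch.equal(fixed_2.saved_tensors[0], dense_2))  # expected: True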
albanD commented 3 years ago

I think this is fixed by https://github.com/pytorch/pytorch/pull/56017 if you use PyTorch nightly. Could you double check?

Varal7 commented 3 years ago

The bug is still there as of commit 935057fc7464d0df6741ffc24d5aed3131533073 (Author: Lily Johnson <lillianjohnson@fb.com>, Date: Tue Jun 8 08:01:01 2021 -0700).

[screenshot of the generated graph]

albanD commented 3 years ago

Oh yes, my bad. These C++ tensors are temporary as well. I guess your SavedVariable improvements will avoid this problem.

soulitzer commented 3 years ago

> Oh yes, my bad. These C++ tensors are temporary as well. I guess your SavedVariable improvements will avoid this problem.

Another quick fix could be to save a reference ourselves, just like what we already do for non-custom functions.
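
A rough sketch of that idea (a hypothetical helper, not the actual pytorchviz code): hold a reference to every unpacked saved tensor for the duration of the graph walk, so each one keeps a distinct, valid id().

def collect_saved_refs(root_fn):
    # Hypothetical helper: walk the autograd graph and keep a reference to
    # every saved tensor of custom-Function nodes, so their id()s stay
    # valid (and distinct) while the dot graph is being built.
    # Assumes backward() has not been run, so the saved tensors still exist.
    keep_alive, seen, stack = [], set(), [root_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        if hasattr(fn, 'saved_tensors'):   # only custom autograd.Function nodes
            keep_alive.extend(fn.saved_tensors)
        stack.extend(next_fn for next_fn, _ in getattr(fn, 'next_functions', ()))
    return keep_alive

refs = collect_saved_refs(z.grad_fn)       # z from the repro above
print(len({id(t) for t in refs}))          # expected: 2, one id per saved tensor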

albanD commented 3 years ago

In general, I think we should do that, yes. The use of id() is only valid as long as the PyObject is alive.
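
A minimal illustration of that point: once a temporary is collected, CPython can hand the same id() to the next object, which is how two distinct saved tensors can end up keyed to the same dot node.

import torch

a_id = id(torch.rand(1))   # the temporary tensor is freed right after id() returns
b_id = id(torch.rand(1))   # a brand-new tensor may then reuse the same address
print(a_id == b_id)        # often True, even though the tensors were distinct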