szagoruyko / pytorchviz

A small package to create visualizations of PyTorch execution graphs
MIT License

How to draw a graph with multiple outputs? #16

Closed: stoneyang closed this issue 6 years ago

stoneyang commented 6 years ago

Hi, @szagoruyko

Thanks for your repo! pytorchviz has made network drafting and demonstration much more convenient.

In #15, you said that multiple outputs are now supported.

But with a net constructed as follows, I keep getting an error like the one in #10.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchviz import make_dot

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)
        self.fc1_2 = nn.Linear(320, 60)
        self.fc2_2 = nn.Linear(60, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        y1 = F.relu(self.fc1(x))
        y1 = F.dropout(y1, training=self.training)
        y1 = self.fc2(y1)
        y2 = F.relu(self.fc1_2(x))
        y2 = F.dropout(y2, training=self.training)
        y2 = self.fc2_2(y2)
        # case 1: works, but returns a single output (uncomment to try it)
        # return F.log_softmax(y1, dim=1) + F.log_softmax(y2, dim=1)
        # case 2: fails like issue #10
        return F.log_softmax(y1, dim=1), F.log_softmax(y2, dim=1)

model = Net()

x = torch.randn(1, 1, 28, 28).requires_grad_(True)
params = dict(list(model.named_parameters()) + [('x', x)])
# for case 1 (single tensor output):
# y = model(x)
# g = make_dot(y, params=params)
# for case 2: pass both outputs to make_dot as a tuple
y1, y2 = model(x)
y = (y1, y2)
g = make_dot(y, params=params)
g.view()

And the traceback:

Traceback (most recent call last):
  File "net_test.py", line 39, in <module>
    g = make_dot(y, params=dict(list(model.named_parameters()) + [('x', x)]))
  File "/usr/local/lib/python3.6/site-packages/torchviz/dot.py", line 60, in make_dot
    add_nodes(var.grad_fn)
AttributeError: 'tuple' object has no attribute 'grad_fn'
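
The traceback shows the installed make_dot reading grad_fn directly from its argument, which a tuple does not have. The sketch below illustrates the general idea of accepting either a single output or a tuple by normalizing to a tuple before walking the graph. It is not torchviz's actual code: collect_graph_nodes is a hypothetical helper, and the SimpleNamespace objects are stand-ins that only mimic the grad_fn and next_functions attributes of autograd objects so the example runs without PyTorch.

```python
from types import SimpleNamespace


def collect_graph_nodes(outputs):
    """Collect every node reachable from one output or a tuple of outputs.

    Hypothetical helper for illustration only; not part of torchviz.
    """
    # Normalize: a bare output becomes a one-element tuple.
    if not isinstance(outputs, (tuple, list)):
        outputs = (outputs,)

    seen = {}  # id(node) -> node; dicts keep insertion order

    def visit(fn):
        if fn is None or id(fn) in seen:
            return
        seen[id(fn)] = fn
        # Assumed layout: next_functions is a sequence of (node, index) pairs.
        for parent, _ in getattr(fn, "next_functions", ()):
            visit(parent)

    for out in outputs:
        visit(out.grad_fn)
    return list(seen.values())


# Stand-ins for two heads that share one upstream node.
shared = SimpleNamespace(name="ViewBackward", next_functions=())
head1 = SimpleNamespace(name="LogSoftmaxBackward1", next_functions=((shared, 0),))
head2 = SimpleNamespace(name="LogSoftmaxBackward2", next_functions=((shared, 0),))
y1 = SimpleNamespace(grad_fn=head1)
y2 = SimpleNamespace(grad_fn=head2)

names = [node.name for node in collect_graph_nodes((y1, y2))]
print(names)
```

The shared node is visited only once, so both heads end up in a single graph rather than two separate ones.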

szagoruyko commented 6 years ago

@stoneyang your snippet works fine for me; please update your torchviz:

pip install git+https://github.com/szagoruyko/pytorchviz

stoneyang commented 6 years ago

@szagoruyko Thanks for your timely reply!

I did install the package that way on my Mac, which has pytorch-0.4.1 installed.

The problem still persists...

P.S.: I've added the traceback and hope it helps. :)

stoneyang commented 6 years ago

The installation messages:

Collecting git+https://github.com/szagoruyko/pytorchviz
  Cloning https://github.com/szagoruyko/pytorchviz to /private/var/folders/db/qjdq4psd0yvfc9cczfbhqcrw0000gn/T/pip-req-build-0ksa_wni
Requirement already satisfied (use --upgrade to upgrade): pytorchviz==0.0.1 from git+https://github.com/szagoruyko/pytorchviz in /usr/local/lib/python3.6/site-packages
Requirement already satisfied: torch in /usr/local/lib/python3.6/site-packages (from pytorchviz==0.0.1) (0.4.1)
Requirement already satisfied: graphviz in /usr/local/lib/python3.6/site-packages (from pytorchviz==0.0.1) (0.8.4)
Building wheels for collected packages: pytorchviz
  Running setup.py bdist_wheel for pytorchviz ... done
  Stored in directory: /private/var/folders/DUMMY
Successfully built pytorchviz

stoneyang commented 6 years ago

Oh god, I finally solved it ... just by upgrading pytorchviz with one extra option (-U):

pip install -U git+https://github.com/szagoruyko/pytorchviz

I'll close this and leave it as a reminder.

Thanks for your time. :)
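
For anyone hitting the same thing: the install log above reports "Requirement already satisfied (use --upgrade to upgrade): pytorchviz==0.0.1", so without -U pip kept the copy that was already installed. A minimal sketch for checking what is installed, using only the standard library (Python 3.8+); installed_version is a hypothetical helper name, and the distribution names are the ones that appear in the log. Since the log shows the version string as 0.0.1, the version alone may not reveal a stale copy; this only confirms what pip currently has on record.

```python
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version of a distribution, or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


# Distribution names taken from the pip log in this thread.
for name in ("pytorchviz", "torch", "graphviz"):
    print(name, installed_version(name))
```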