szagoruyko / pytorchviz

A small package to create visualizations of PyTorch execution graphs
MIT License

make_dot_from_trace fails for very simple bivariate function #8

Closed: wpeebles closed this issue 6 years ago

wpeebles commented 6 years ago

Is there a reason why the following simple example fails? I'm having a hard time figuring it out from the error message.

from torch.autograd import Variable
import torch.nn as nn
import torch
from torchviz import make_dot_from_trace

class toy(nn.Module):

    def __init__(self):
        super(toy, self).__init__()

    def forward(self, x, y):
        return x + y

f = toy()
a = Variable(torch.FloatTensor([3.0]), requires_grad=True)
b = Variable(torch.FloatTensor([2.0]), requires_grad=True)

trace, _ = torch.jit.trace(f, args=(a,b))
make_dot_from_trace(trace)

The error message from running this is:

File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/torchviz/dot.py", line 110, in make_dot_from_trace
    torch.onnx._optimize_trace(trace, False)
TypeError: _optimize_trace() takes 1 positional argument but 2 were given

visionxyz commented 6 years ago

I'm running into the same problem.

szagoruyko commented 6 years ago

Reinstall and change the last lines of your snippet to:

with torch.onnx.set_training(f, False):
    trace, _ = torch.jit.get_trace_graph(f, args=(a, b))
make_dot_from_trace(trace)

Then it works.
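
For anyone puzzled by the original TypeError: it says the call site in torchviz passed two positional arguments to a function that, in the installed PyTorch, accepts only one. This usually points to a version mismatch between the two packages. Below is a minimal sketch of how such a mismatch could be checked before calling. Note that `optimize_trace_stub`, `accepts_positional`, and `call_error` are hypothetical names made up for illustration. The stub only imitates a one-argument signature and is not PyTorch's actual `_optimize_trace`.

```python
import inspect


def optimize_trace_stub(trace):
    """Hypothetical stand-in for a function that takes one positional argument."""
    return trace


def accepts_positional(func, n):
    """Return True if func can be called with n positional arguments."""
    try:
        # bind() raises TypeError when the arguments do not fit the signature.
        inspect.signature(func).bind(*([None] * n))
    except TypeError:
        return False
    return True


def call_error(func, *args):
    """Call func and return the TypeError message, or None if the call succeeds."""
    try:
        func(*args)
    except TypeError as exc:
        return str(exc)
    return None


if __name__ == "__main__":
    print(accepts_positional(optimize_trace_stub, 1))
    print(accepts_positional(optimize_trace_stub, 2))
    print(call_error(optimize_trace_stub, "trace", False))
```

Running the same kind of check against the real function in your installed PyTorch would show whether the torchviz version you have matches it, assuming the function is a plain Python callable that `inspect.signature` can introspect.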