szagoruyko / pytorchviz

A small package to create visualizations of PyTorch execution graphs
MIT License

'NotImplementedError:' when passing multiple input to model() #70

Closed: kiristern closed this issue 2 years ago

kiristern commented 2 years ago

Hi, sorry if this is super trivial, but the error message isn't very clear about what exactly is wrong. I have:

import torch
from torch_geometric.nn import GCNConv


class VGCNEncoder(torch.nn.Module):
    """
    First GCN layer generates a lower-dimensional feature matrix;
    second GCN layer generates mu and log sigma^2.
    """

    def __init__(self, in_channels, out_channels):
        super(VGCNEncoder, self).__init__()
        # cached=True is only appropriate for transductive learning
        self.conv1 = GCNConv(in_channels, 2 * out_channels, cached=True)
        self.conv_mu = GCNConv(2 * out_channels, out_channels, cached=True)
        self.conv_logvar = GCNConv(2 * out_channels, out_channels, cached=True)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        # returns a (mu, logvar) tuple
        return self.conv_mu(x, edge_index), self.conv_logvar(x, edge_index)

and I am trying to visualize the model with make_dot, but when I create the dummy input and pass it through the model:

batch = next(iter(data.x)).requires_grad_(True)
yhat = model(batch, train_data.edge_index)

I get the following error:

/tmp/ipykernel_28889/3816987256.py in <module>
----> 1 yhat = model(batch, train_data.edge_index) #give dummy batch to forward()

~/miniconda3/envs/VGAE_graphviz/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~/miniconda3/envs/VGAE_graphviz/lib/python3.8/site-packages/torch/nn/modules/module.py in _forward_unimplemented(self, *input)
    199         registered hooks while the latter silently ignores them.
    200     """
--> 201     raise NotImplementedError
    202 
    203 

NotImplementedError: 

Are the model inputs not x and edge_index?
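A minimal sketch of what the traceback points at, using only PyTorch. The last frame is `_forward_unimplemented`, the default `forward` that `torch.nn.Module` provides; calling a module whose class does not define its own `forward` raises `NotImplementedError` regardless of which positional inputs are passed. So x and edge_index themselves are apparently not the problem; the object bound to `model` seems to have no `forward` of its own. Two common ways this happens: `forward` ends up outside the class body (for example, dedented in a notebook cell), or `model` is some other module, such as a wrapper around `VGCNEncoder`, that does not define `forward`. The sketch also checks a separate point: iterating over a tensor walks its first dimension, so `next(iter(data.x))` yields only the first node's feature vector, not the node-feature matrix that `edge_index` refers to. `EncoderMissingForward`, `ToyEncoder`, and `raises_not_implemented` are hypothetical names, and `torch.nn.Linear` stands in for `GCNConv` (an assumption made so the example needs nothing beyond torch).

```python
import torch


class EncoderMissingForward(torch.nn.Module):
    """Hypothetical stand-in: forward was dedented out of the class body by mistake."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.lin = torch.nn.Linear(in_channels, out_channels)


def forward(self, x, edge_index):  # module-level function, NOT a method of the class
    return self.lin(x), self.lin(x)


class ToyEncoder(torch.nn.Module):
    """Hypothetical stand-in for VGCNEncoder: two inputs in, a (mu, logvar) tuple out.

    torch.nn.Linear replaces GCNConv, so edge_index is accepted but ignored here.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.lin1 = torch.nn.Linear(in_channels, 2 * out_channels)
        self.lin_mu = torch.nn.Linear(2 * out_channels, out_channels)
        self.lin_logvar = torch.nn.Linear(2 * out_channels, out_channels)

    def forward(self, x, edge_index):
        h = self.lin1(x).relu()
        return self.lin_mu(h), self.lin_logvar(h)


def raises_not_implemented(model: torch.nn.Module, *inputs) -> bool:
    """True if calling the module falls through to nn.Module's default forward."""
    try:
        model(*inputs)
    except NotImplementedError:
        return True
    return False


num_nodes, num_features = 5, 3
x = torch.randn(num_nodes, num_features)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])  # 4 directed edges

# next(iter(x)) yields only row 0: one node's features, shape [num_features].
first_row = next(iter(x))

# Use the whole node-feature matrix as the dummy input instead.
dummy_x = x.clone().requires_grad_(True)
mu, logvar = ToyEncoder(num_features, 2)(dummy_x, edge_index)
```

For the real encoder, the analogous dummy input would be the whole matrix, e.g. data.x.clone().requires_grad_(True), together with the edge_index that belongs to the same graph.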