waleedka / hiddenlayer

Neural network graphs and training metrics for PyTorch, Tensorflow, and Keras.
MIT License

RNN can not display correctly #8

Open AutuanLiu opened 5 years ago

AutuanLiu commented 5 years ago

We need hl.transforms.Rename() to rename the RNN node.

    tsfm = [hl.transforms.Rename(op='prim::PythonOp', to='LSTM')]

image

waleedka commented 5 years ago

Thanks for the report. Admittedly, I haven't tested with RNNs. Would you mind sharing the sample code used to generate this?

AutuanLiu commented 5 years ago
1. RNN (LSTM) code:

    import torch
    import hiddenlayer as hl
    from torch import nn


    class RNN_Net(nn.Module):
        def __init__(self, input_dim, hidden_dim, output_dim, num_layers=1):
            super().__init__()
            self.hidden_dim = hidden_dim
            self.num_layers = num_layers
            self.rnn = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers, batch_first=True)
            self.fc = nn.Linear(hidden_dim, output_dim)

        def forward(self, x):
            # x has shape (batch, seq_len, input_dim) because batch_first=True
            hidden = self.initHidden(x.size(0))
            y, _ = self.rnn(x, hidden)
            # Use the output of the last time step only
            return self.fc(y[:, -1, :])

        def initHidden(self, batchsize):
            # Create zero states on the same device and dtype as the parameters
            weight = next(self.parameters())
            h0 = weight.new_zeros(self.num_layers, batchsize, self.hidden_dim)
            return (h0, h0)


    model = RNN_Net(5, 15, 5)
    hl.build_graph(model, torch.zeros([32, 20, 5]))


2. The graph of the neural network with the RNN (LSTM):
![image](https://user-images.githubusercontent.com/15994006/48198998-b15b8f00-e395-11e8-9bf5-4ac79141ec41.png)

3. Warning

    UserWarning: ONNX export failed on RNN/GRU/LSTM because batch_first not supported
      warnings.warn("ONNX export failed on " + op + " because " + msg + " not supported")



4. **LSTM** can be rendered when `batch_first=False`.

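
For reference, the workaround in item 4 could look roughly like the sketch below. It keeps batch-first inputs at the call site but builds the LSTM with `batch_first=False` and transposes inside `forward`. The class name `SeqFirstLSTMNet` is made up for illustration, the initial hidden state is left at PyTorch's default (zeros), and the `hl.build_graph` call is left as a comment because whether it renders cleanly may depend on the installed PyTorch and hiddenlayer versions.

```python
import torch
from torch import nn


class SeqFirstLSTMNet(nn.Module):  # hypothetical name, not from the report
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=1):
        super().__init__()
        # batch_first=False: the LSTM expects (seq_len, batch, input_dim)
        self.rnn = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers, batch_first=False)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # Callers still pass (batch, seq_len, input_dim); swap the first two axes
        x = x.transpose(0, 1).contiguous()
        y, _ = self.rnn(x)  # initial states default to zeros
        # y has shape (seq_len, batch, hidden_dim); take the last time step
        return self.fc(y[-1])


model = SeqFirstLSTMNet(5, 15, 5)
out = model(torch.zeros(32, 20, 5))
print(tuple(out.shape))

# With hiddenlayer installed, the graph could then be built as in the report:
# hl.build_graph(model, torch.zeros([32, 20, 5]))
```

The transpose will presumably show up as an extra node in the graph, so this trades one rendering quirk for a slightly busier diagram.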
waleedka commented 5 years ago

@AutuanLiu I want to give an update on this. Thank you for raising the issue and providing details, and sorry it took so long.

I ran experiments with LSTMs and tried to find a simple pattern that would make them easy to render. Unfortunately, I couldn't find one. The method you used works only in a subset of cases, and PythonOp can also represent other operations.
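
To illustrate the PythonOp concern, here is a toy sketch in plain Python. It is not hiddenlayer's actual graph or transform code: `Node` and `rename_op` are hypothetical stand-ins, and the node list is invented. It only shows that a rename keyed on the op type relabels every matching node, including ones that are not LSTMs.

```python
from dataclasses import dataclass


@dataclass
class Node:
    name: str   # unique node id (made up for this example)
    op: str     # operation type as a tracer might report it
    label: str  # text a renderer would display


def rename_op(nodes, op, to):
    """Return copies of the nodes, relabeling every node whose op matches."""
    return [Node(n.name, n.op, to if n.op == op else n.label) for n in nodes]


nodes = [
    Node("n1", "prim::PythonOp", "PythonOp"),  # stands in for the LSTM
    Node("n2", "onnx::Gemm", "Linear"),
    Node("n3", "prim::PythonOp", "PythonOp"),  # stands in for an unrelated custom op
]

renamed = rename_op(nodes, op="prim::PythonOp", to="LSTM")
labels = [n.label for n in renamed]
print(labels)  # ['LSTM', 'Linear', 'LSTM']
```

The third node ends up labeled as an LSTM even though, in this toy setup, it represents something else entirely.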

So I think your solution is useful for your case, but shouldn't be built into the library as it might cause unintended problems for other network types. At this point, I'm afraid I don't have a good solution for LSTMs. I hope I manage to find some free time in the near future to dive deeper into how PyTorch represents LSTMs and find a general solution that works in all (or most) cases.

AutuanLiu commented 5 years ago

Thank you!