Open AutuanLiu opened 5 years ago
Thanks for the report. Admittedly, I haven't tested with RNNs. Would you mind sharing the sample code used to generate this?
```python
import torch
import hiddenlayer as hl
from torch import nn


class RNN_Net(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.rnn = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # x has shape (batch, seq_len, input_dim) because batch_first=True
        hidden = self.initHidden(x.size(0))
        y, _ = self.rnn(x, hidden)
        # Use the output of the last time step only
        return self.fc(y[:, -1, :])

    def initHidden(self, batchsize):
        # Create zero states on the same device and dtype as the model parameters
        weight = next(self.parameters())
        h0 = weight.new_zeros(self.num_layers, batchsize, self.hidden_dim)
        return (h0, h0)


model = RNN_Net(5, 15, 5)
hl.build_graph(model, torch.zeros([32, 20, 5]))
```
2. The graph of the neural network with an RNN (LSTM).
![image](https://user-images.githubusercontent.com/15994006/48198998-b15b8f00-e395-11e8-9bf5-4ac79141ec41.png)
3. Warning
```
UserWarning: ONNX export failed on RNN/GRU/LSTM because batch_first not supported
  warnings.warn("ONNX export failed on " + op + " because " + msg + " not supported")
```
4. **LSTM** can be rendered when `batch_first=False`.
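
For reference, here is a minimal sketch of what the `batch_first=False` variant might look like. It assumes the input is transposed to time-major layout `(seq_len, batch, input_dim)` before the call. The names `TimeMajorRNNNet` and `time_major` are made up for this illustration and are not part of the original report, and the `hl.build_graph` call is left as a comment because it was not re-run here.

```python
import torch
from torch import nn


class TimeMajorRNNNet(nn.Module):
    """Illustrative variant of the model that keeps nn.LSTM's default batch_first=False."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        # batch_first defaults to False, so inputs are (seq_len, batch, input_dim)
        self.rnn = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        batch_size = x.size(1)  # the batch dimension is now the second one
        h0 = x.new_zeros(self.num_layers, batch_size, self.hidden_dim)
        c0 = x.new_zeros(self.num_layers, batch_size, self.hidden_dim)
        y, _ = self.rnn(x, (h0, c0))
        # y is (seq_len, batch, hidden_dim); take the last time step
        return self.fc(y[-1])


model = TimeMajorRNNNet(5, 15, 5)
batch_major = torch.zeros([32, 20, 5])                  # layout used in the report
time_major = batch_major.transpose(0, 1).contiguous()   # (20, 32, 5)
out = model(time_major)
print(out.shape)

# With hiddenlayer installed, the graph call would presumably become:
# hl.build_graph(model, time_major)
```

The only changes from the model above are the dropped `batch_first=True` argument and the indexing in `forward`, which now treats dimension 0 as time.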
@AutuanLiu I want to give an update on this. Thank you for raising the issue and providing details. And sorry it took a long time.
I ran experiments with LSTMs and tried to find a simple pattern that would make them easy to render. Unfortunately, I couldn't find one. The method you used works only in a subset of cases. Also, `PythonOp` can represent other operations as well.
So I think your solution is useful for your case, but it shouldn't be built into the library, as it might cause unintended problems for other network types. At this point, I'm afraid I don't have a good solution for LSTMs. I hope to find some free time in the near future to dive deeper into how PyTorch represents LSTMs and find a general solution that works in all (or most) cases.
Thank you!
We need `hl.transforms.Rename()` to rename the RNN node.
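
To make the caveat about `PythonOp` concrete, here is a small stand-alone sketch. It does not call hiddenlayer: `Node` and `rename_by_op` are hypothetical stand-ins that only mimic the idea of relabeling nodes whose op matches a regex, and the trace is made up. It assumes that an LSTM and some unrelated custom function both appear in the traced graph as `PythonOp`.

```python
import re
from dataclasses import dataclass


@dataclass
class Node:
    """Hypothetical stand-in for a graph node; not hiddenlayer's own class."""
    name: str
    op: str


def rename_by_op(nodes, pattern, to):
    """Return copies of `nodes`, relabeling every node whose op matches `pattern`."""
    regex = re.compile(pattern)
    return [Node(n.name, to if regex.match(n.op) else n.op) for n in nodes]


# Made-up trace: the LSTM and an unrelated custom function share the op "PythonOp".
trace = [
    Node("rnn", "PythonOp"),
    Node("custom_fn", "PythonOp"),
    Node("fc", "Linear"),
]

renamed = rename_by_op(trace, r"PythonOp", "LSTM")
for node in renamed:
    print(node.name, "->", node.op)
```

In this sketch both `rnn` and `custom_fn` end up labeled `LSTM`, which is the kind of unintended relabeling a rule keyed on the op name alone could cause in other networks.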