onnx / onnx-tensorflow

Tensorflow Backend for ONNX

Convert onnx multi-layer LSTM bug #878

Open · aishoot opened this issue 3 years ago

aishoot commented 3 years ago

Describe the bug
Hello, I can convert both single-layer and multi-layer LSTMs to TensorFlow. However, for a multi-layer LSTM the outputs from onnx and onnx-tf do not match, while for a single-layer LSTM they are identical.
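A minimal way to check the mismatch is a side-by-side run of ONNX Runtime and the onnx-tf backend (a sketch: the file name "mnist_rnn.onnx" and the random input are assumptions, chosen to match the training script below):

# Sketch: compare ONNX Runtime against the onnx-tf backend on one input.
import numpy as np
import onnx
import onnxruntime as ort
from onnx_tf.backend import prepare

x = np.random.randn(1, 28, 28).astype(np.float32)  # (batch, seq, input_size)

# Reference result from ONNX Runtime.
sess = ort.InferenceSession('mnist_rnn.onnx')
ort_out = sess.run(None, {sess.get_inputs()[0].name: x})[0]

# Result from the onnx-tf backend.
tf_rep = prepare(onnx.load('mnist_rnn.onnx'))
tf_out = np.asarray(tf_rep.run(x)[0])

# Near zero for num_layers=1, large for num_layers=2.
print(np.abs(ort_out - tf_out).max())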

Python, ONNX, ONNX-TF, Tensorflow version

This section can be obtained by running get_version.py from the util folder.

Additional context
PyTorch code:

import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable

num_layers = 2  # !!! onnx-tf outputs match ONNX only when num_layers = 1
sequence_length = 28
input_size = 28
hidden_size = 128
num_classes = 10
num_epochs = 2
batch_size = 100
learning_rate = 0.01

train_dataset = dsets.MNIST(root='data',
                            train=True,
                            transform=transforms.ToTensor(),
                            download=True)
test_dataset = dsets.MNIST(root='data',
                           train=False,
                           transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)  # bidirectional=True
        self.fc   = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        out, _ = self.lstm(x)   # , (self.h0, self.c0)
        out    = self.fc(out[:, -1, :]) #(batch, seq, num_classes)
        return out

rnn = RNN(input_size, hidden_size, num_layers, num_classes)
rnn.cuda()

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(rnn.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = Variable(images.view(-1, sequence_length, input_size)).cuda()
        labels = Variable(labels).cuda()
        optimizer.zero_grad()
        outputs = rnn(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print('Epoch [%d/%d], Step [%d/%d], Loss: %.4f' %(epoch+1, num_epochs, i+1, len(train_dataset)//batch_size, loss.item()))
correct = 0
total = 0
for images, labels in test_loader:
    images = Variable(images.view(-1, sequence_length, input_size)).cuda()    
    outputs = rnn(images)
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted.cpu() == labels).sum()
print('Test Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))

torch.save(rnn.state_dict(), "mnist_rnn.pth")
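(The export step is not shown in the report; presumably something along these lines was used. The file name and the input/output names are my assumptions.)

# Assumed ONNX export step, not part of the original report.
rnn.cpu().eval()
dummy = torch.randn(1, sequence_length, input_size)
torch.onnx.export(rnn, dummy, 'mnist_rnn.onnx',
                  input_names=['input'], output_names=['output'])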

shocoladka commented 3 years ago

As I mentioned in https://github.com/onnx/onnx-tensorflow/issues/796, this is probably due to

https://github.com/onnx/onnx-tensorflow/blob/c63d4351c7752a769cdc9a1bfcf79ffd140e0e6a/onnx_tf/handlers/backend/rnn_mixin.py#L28-L36

namely, rnn_cell being a global class variable. @chinhuang007 it would be really nice to have this fixed in the new release.
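To make the failure mode concrete, here is a minimal sketch of that pitfall (illustrative names only, not the actual onnx-tf source): because the cell is cached on the class, the cell built for the first LSTM node is silently reused for every later LSTM node, so the second layer of a stacked LSTM runs with the first layer's weights.

# Minimal sketch of the class-level cache pitfall (names are illustrative).
class RNNMixinSketch:
    rnn_cell = None  # class attribute: shared by every handler invocation

    @classmethod
    def build_cell(cls, weights):
        if cls.rnn_cell is None:        # only the first call builds a cell
            cls.rnn_cell = ('cell', weights)
        return cls.rnn_cell             # later calls return stale weights

print(RNNMixinSketch.build_cell('layer-1 weights'))  # ('cell', 'layer-1 weights')
print(RNNMixinSketch.build_cell('layer-2 weights'))  # ('cell', 'layer-1 weights') - wrong!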

aishoot commented 3 years ago

@shocoladka will this bug be fixed in the new release? If so, when? Any good news?

shocoladka commented 3 years ago

@aishoot No idea 😄

chinhuang007 commented 3 years ago

This issue is not trivial. It will take some time to investigate a fix. So I think the 1.8 release will not include it, unless someone from the community can jump in and fix it in a few days, which would be very much appreciated.

aishoot commented 3 years ago

@shocoladka thanks

nightfuryyy commented 2 years ago

@shocoladka did you fix that bug?

nightfuryyy commented 2 years ago

I have just fixed it. You just need to comment out the "if cls.rnn_cell is None:" line and everything will be OK.
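In terms of the sketch above, that workaround amounts to dropping the cache check so every LSTM node builds its own cell (again an illustration, not the exact onnx-tf diff; when editing the real file, remember to adjust the indentation of the guarded body):

# Sketch of the workaround: with the guard removed, each call builds a
# fresh cell from its own weights instead of reusing the cached one.
class RNNMixinPatched:
    rnn_cell = None

    @classmethod
    def build_cell(cls, weights):
        cls.rnn_cell = ('cell', weights)  # guard removed: always rebuild
        return cls.rnn_cell

print(RNNMixinPatched.build_cell('layer-1 weights'))  # ('cell', 'layer-1 weights')
print(RNNMixinPatched.build_cell('layer-2 weights'))  # ('cell', 'layer-2 weights')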