pytorch / pytorch


JITed program is getting stuck #18320

Closed: hhsecond closed this issue 5 years ago

hhsecond commented 5 years ago

🐛 Bug

I have a program that's JITed using both script and trace. While the non-JITed version executes without any problem, the JITed version gets stuck in a while loop.

To Reproduce

The snippet below is the complete script (a variant of the PyTorch Chatbot example from the docs). I suspect the problem is in the wrapper function: it never returns control and gets stuck in the while loop inside the wrapper function itself.

import json

import torch
import torch.nn as nn
import torch.nn.functional as F

PAD_token = 0  # Used for padding short sentences
SOS_token = 1  # Start-of-sentence token
EOS_token = 2  # End-of-sentence token
MAX_LENGTH = 10  # Maximum sentence length (assumed value; matches max_length in wrapper())

# TODO: `.to(device=device)` for all tensors

class EncoderRNN(nn.Module):
    def __init__(self, hidden_size, n_layers=1, dropout=0):
        super(EncoderRNN, self).__init__()  # required before assigning submodules
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(voc.num_words, hidden_size)
        self.gru = nn.GRU(
            hidden_size, hidden_size, n_layers,
            dropout=(0 if n_layers == 1 else dropout), bidirectional=True)

    def forward(self, input_seq, input_lengths, hidden=None):
        embedded = self.embedding(input_seq)
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, input_lengths)
        outputs, hidden = self.gru(packed, hidden)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs)
        t1 = outputs[:, :, :self.hidden_size]
        t2 = outputs[:, :, self.hidden_size:]
        outputs = t1 + t2
        return outputs, hidden

class Attn(nn.Module):
    def __init__(self, hidden_size):
        super(Attn, self).__init__()
        self.hidden_size = hidden_size

    def forward(self, hidden, encoder_output):
        attn_energies = torch.sum(hidden * encoder_output, dim=2)
        attn_energies = attn_energies.t()
        return F.softmax(attn_energies, dim=1).unsqueeze(1)

class LuongAttnDecoderRNN(nn.Module):
    def __init__(
            self, hidden_size,
            output_size, n_layers=1, dropout=0.1):
        super(LuongAttnDecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout

        # Define layers
        self.embedding = nn.Embedding(voc.num_words, hidden_size)
        self.embedding_dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(
            hidden_size, hidden_size, n_layers,
            dropout=(0 if n_layers == 1 else dropout))
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.attn = Attn(hidden_size)

    def forward(self, input_seq, last_hidden, encoder_outputs):
        embedded = self.embedding(input_seq)
        embedded = self.embedding_dropout(embedded)
        rnn_output, hidden = self.gru(embedded, last_hidden)
        attn_weights = self.attn(rnn_output, encoder_outputs)
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))
        rnn_output = rnn_output.squeeze(0)
        context = context.squeeze(1)
        concat_input = torch.cat((rnn_output, context), 1)
        concat_output = torch.tanh(self.concat(concat_input))
        output = self.out(concat_output)
        output = F.softmax(output, dim=1)
        return output, hidden

class Voc:
    def __init__(self, name):
        self.name = name
        self.trimmed = False
        self.word2index = {}
        self.word2count = {}
        self.index2word = {
            PAD_token: "PAD",
            SOS_token: "SOS",
            EOS_token: "EOS"}
        self.num_words = 3  # Count SOS, EOS, PAD

voc = Voc(name=None)
voc.num_words = 7826  # TODO - change this hardcoding after debugging

hidden_size = 500
encoder_n_layers = 2
decoder_n_layers = 2
dropout = 0.1

encoder = EncoderRNN(hidden_size, encoder_n_layers, dropout)
seq = torch.ones((MAX_LENGTH, 1), dtype=torch.long)
seq_length = torch.tensor([seq.size()[0]])
traced_encoder = torch.jit.trace(encoder, (seq, seq_length))

decoder = LuongAttnDecoderRNN(
    hidden_size, voc.num_words, decoder_n_layers, dropout)
test_encoder_outputs, test_encoder_hidden = traced_encoder(seq, seq_length)
test_decoder_hidden = test_encoder_hidden[:decoder.n_layers]
test_decoder_input = torch.LongTensor(1, 1).random_(0, voc.num_words)
traced_decoder = torch.jit.trace(
    decoder, (test_decoder_input, test_decoder_hidden, test_encoder_outputs))

def wrapper(input_seq, input_length):
    PAD_token = 0
    SOS_token = 1
    EOS_token = 2
    max_length = 10
    n_layers = 2
    e_outputs, e_hidden = traced_encoder(input_seq, input_length)
    d_hidden = e_hidden[:n_layers]
    d_input = torch.ones(1, 1, dtype=torch.long)
    d_input *= SOS_token
    # TODO - put EOS check somehow
    all_tokens = torch.zeros([0], dtype=torch.long)
    while max_length > 0:
        max_length -= 1
        d_output, d_hidden = traced_decoder(d_input, d_hidden, e_outputs)
        _, d_input = torch.max(d_output, dim=1)
        all_tokens = torch.cat((all_tokens, d_input), dim=0)
        d_input = torch.unsqueeze(d_input, 0)
    return all_tokens

def run():
    indexes_batch = [[787, 572, 2]]  # "hello sir + EOS"
    lengths = torch.tensor([3])
    input_batch = torch.LongTensor(indexes_batch).transpose(0, 1)
    tokens = wrapper(input_batch, lengths)
    return tokens


if __name__ == "__main__":
    print(run())


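The report doesn't show exactly how the wrapper itself was compiled. As a minimal sketch under that assumption, one way to combine the two modes is to trace the loop-free decoder and script a module that owns the data-dependent loop, because torch.jit.trace only records the operations executed for the example input and would unroll the while loop. TinyDecoder and GreedyLoop are hypothetical stand-ins, not the classes from the repro above.

```python
import torch
import torch.nn as nn


class TinyDecoder(nn.Module):
    """Hypothetical stand-in for LuongAttnDecoderRNN: token -> probability distribution."""

    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.out = nn.Linear(hidden_size, vocab_size)

    def forward(self, token: torch.Tensor) -> torch.Tensor:
        return torch.softmax(self.out(self.embedding(token)), dim=1)


class GreedyLoop(nn.Module):
    """Hypothetical wrapper: the while loop lives here so it can be scripted, not traced."""

    def __init__(self, decoder: nn.Module, max_length: int):
        super().__init__()
        self.decoder = decoder
        self.max_length = max_length

    def forward(self, start: torch.Tensor) -> torch.Tensor:
        d_input = start
        all_tokens = torch.zeros([0], dtype=torch.long)
        steps = self.max_length
        while steps > 0:
            steps -= 1
            d_output = self.decoder(d_input)
            # The indices returned by torch.max are integral and never require grad.
            _, d_input = torch.max(d_output, dim=1)
            all_tokens = torch.cat([all_tokens, d_input], dim=0)
        return all_tokens


torch.manual_seed(0)
eager_decoder = TinyDecoder(vocab_size=10, hidden_size=8)
# Tracing is fine for the decoder: it has no data-dependent control flow.
traced_decoder = torch.jit.trace(eager_decoder, torch.tensor([1]))
# Scripting keeps the while loop as real control flow instead of unrolling it.
scripted_loop = torch.jit.script(GreedyLoop(traced_decoder, max_length=5))

start = torch.tensor([1])  # plays the role of SOS_token
tokens = scripted_loop(start)
```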
Expected behavior

The non-JITed version returns the tokens in less than a second, and I would expect the same from the JITed version.


Environment

PyTorch version: 1.0.1.post2
Is debug build: No
CUDA used to build PyTorch: None

OS: Ubuntu 18.10
GCC version: (Ubuntu 8.2.0-7ubuntu1) 8.2.0
CMake version: version 3.12.1

Python version: 3.7
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA

Versions of relevant libraries:
[pip] numpy==1.16.2
[pip] torch==1.0.1.post2
[pip] torchvision==0.2.2
[conda] blas 1.0 mkl
[conda] mkl 2019.1 144
[conda] mkl_fft 1.0.10 py37ha843d7b_0
[conda] mkl_random 1.0.2 py37hd81dba3_0
[conda] pytorch-cpu 1.0.1 py3.7_cpu_2 pytorch
[conda] torchvision-cpu 0.2.2 py_3 pytorch

eellison commented 5 years ago

This looks like an infinite loop in the requires_grad analysis; looking into it.

hhsecond commented 5 years ago

@eellison Thanks a lot for looking into the issue. Would you mind giving me a bit more info about the issue?

eellison commented 5 years ago

@hhsecond I'm landing the fix now, so hopefully this won't affect you for much longer.

The issue was in our requires_grad analysis: we assumed that a loop's input and output would converge to either both requiring grad or both not requiring it.

Consider this test example:

    def test_requires_grad_loop(self):
        def test(x, y, z):
            # type: (Tensor, Tensor, int) -> Tensor
            for _ in range(z):
                x = y
            return x

The loop input is the value of x when we enter the loop and the output is the value of x when we exit. If x requires grad but y doesn't, then the loop input will require grad but the loop output won't.

This was triggered in your example (IIRC) because the analysis marked d_input as requiring grad when it entered the loop but not when it exited (the indices returned by torch.max are an integral tensor, which can't require grad).

If you set d_input to not require grad before the loop, I think the error would no longer happen.
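To make the failure mode concrete, here is a simplified, hypothetical model of a loop fixpoint over requires_grad flags. It is not PyTorch's actual implementation, and naive_fixpoint, joined_fixpoint and body_x_equals_y are illustrative names. The naive version waits for the loop body's outputs to equal its inputs, which never happens for x = y when x requires grad and y doesn't; the joined version ORs the flags together and stops as soon as they stop changing.

```python
from typing import Callable, List

Flags = List[bool]  # one requires_grad flag per loop-carried value


def join(a: Flags, b: Flags) -> Flags:
    # A value may require grad if it does on any path reaching the loop header.
    return [x or y for x, y in zip(a, b)]


def naive_fixpoint(entry: Flags, body: Callable[[Flags], Flags], max_iters: int = 100) -> Flags:
    """Faulty assumption: wait until the body's outputs equal its inputs."""
    inputs = list(entry)
    for _ in range(max_iters):  # bounded here only so the sketch cannot hang
        outputs = body(inputs)
        if outputs == inputs:
            return inputs
        # The entry value keeps flowing into the header, so inputs stay True.
        inputs = join(entry, outputs)
    raise RuntimeError("no convergence: loop input and output disagree forever")


def joined_fixpoint(entry: Flags, body: Callable[[Flags], Flags]) -> Flags:
    """Stop once the joined header state no longer changes."""
    inputs = list(entry)
    while True:
        new_inputs = join(inputs, body(inputs))
        if new_inputs == inputs:
            return inputs
        inputs = new_inputs


Y_REQUIRES_GRAD = False  # y in the test example does not require grad


def body_x_equals_y(flags: Flags) -> Flags:
    # Loop body "x = y": the outgoing x has y's flag, whatever x was on entry.
    return [Y_REQUIRES_GRAD]
```

Because flags can only flip from False to True under the join, the joined version always terminates; it conservatively reports that the loop-carried x may require grad.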

hhsecond commented 5 years ago

Great, thanks a lot for the explanation. I did try torch.no_grad(), but apparently we can't use that inside the JIT yet. I haven't tried setting .requires_grad=False. Will give it a shot.

hhsecond commented 5 years ago

Hi @eellison, I tried setting requires_grad=False, but apparently aten::ones doesn't understand that keyword argument and raised "keyword argument requires_grad unknown". I guess your fix is the only way to go, then!

eellison commented 5 years ago

@hhsecond you can also try .detach(), I think that would work.

hhsecond commented 5 years ago

@eellison I think we got misled a bit. d_input already has requires_grad=False when it is created. It's the return values from the encoder and decoder that have requires_grad=True. I tried detaching them, and it did not work.
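A quick way to check which tensors actually require grad, as discussed in the comment above, is to inspect .requires_grad directly in eager mode. This sketch uses a small nn.Linear as a hypothetical stand-in for the decoder. It shows that a freshly created long tensor doesn't require grad, the softmax output of a module with trainable parameters does, the integral indices from torch.max never do, and .detach() returns a tensor that doesn't.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
decoder = nn.Linear(4, 6)  # hypothetical stand-in; its weight and bias require grad

# Like d_input in the repro: created without requires_grad.
d_input = torch.ones(1, 1, dtype=torch.long)
print(d_input.requires_grad)  # False

# Output of a module with trainable parameters (grad mode is on by default).
d_output = torch.softmax(decoder(torch.randn(1, 4)), dim=1)
print(d_output.requires_grad)  # True

# torch.max returns (values, indices); the integral indices cannot require grad.
values, indices = torch.max(d_output, dim=1)
print(values.requires_grad, indices.dtype, indices.requires_grad)  # True torch.int64 False

# .detach() shares storage with d_output but drops its autograd history.
print(d_output.detach().requires_grad)  # False
```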