allenai / allennlp

An open-source NLP research library, built on PyTorch.
http://www.allennlp.org
Apache License 2.0

RuntimeError: rnn: hx is not contiguous when using Seq2SeqEncoder #1486

Closed: dangitstam closed this issue 6 years ago

dangitstam commented 6 years ago

I'm doing NMT, and my model initializes the hidden state of the LSTM that generates the translation in the target language. For some reason, about 18% of the way through the first training epoch, I get a PyTorch error saying that hx is not contiguous when encoding the target utterance.

I inserted asserts to check that the hidden state I'm passing is contiguous, but the asserts are never triggered and the error still happens.

I should also mention that this only happens when training on the GPU; CPU training seems okay.

Here is the backtrace from stderr.log:

Traceback (most recent call last):
  File "/usr/lib64/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib64/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/dangt7/nmt/git/le-traducteur/allennlp/run.py", line 18, in <module>
    main(prog="allennlp")
  File "/home/dangt7/nmt/git/le-traducteur/allennlp/commands/__init__.py", line 70, in main
    args.func(args)
  File "/home/dangt7/nmt/git/le-traducteur/allennlp/commands/train.py", line 103, in train_model_from_args
    args.recover)
  File "/home/dangt7/nmt/git/le-traducteur/allennlp/commands/train.py", line 133, in train_model_from_file
    return train_model(params, serialization_dir, file_friendly_logging, recover)
  File "/home/dangt7/nmt/git/le-traducteur/allennlp/commands/train.py", line 322, in train_model
    metrics = trainer.train()
  File "/home/dangt7/nmt/git/le-traducteur/allennlp/training/trainer.py", line 707, in train
    train_metrics = self._train_epoch(epoch)
  File "/home/dangt7/nmt/git/le-traducteur/allennlp/training/trainer.py", line 480, in _train_epoch
    loss = self._batch_loss(batch, for_training=True)
  File "/home/dangt7/nmt/git/le-traducteur/allennlp/training/trainer.py", line 415, in _batch_loss
    output_dict = self._model(**batch)
  File "/home/dangt7/le-traducteur/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/dangt7/nmt/git/le-traducteur/library/models/english_to_french_mt.py", line 122, in forward
    hidden_state=fr_translation_primer)
  File "/home/dangt7/le-traducteur/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/dangt7/nmt/git/le-traducteur/allennlp/modules/seq2seq_encoders/pytorch_seq2seq_wrapper.py", line 83, in forward
    self.sort_and_run_forward(self._module, inputs, mask, hidden_state)
  File "/home/dangt7/nmt/git/le-traducteur/allennlp/modules/encoder_base.py", line 116, in sort_and_run_forward
    module_output, final_states = module(packed_sequence_input, initial_states)
  File "/home/dangt7/le-traducteur/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/dangt7/le-traducteur/lib/python3.6/site-packages/torch/nn/modules/rnn.py", line 192, in forward
    output, hidden = func(input, self.all_weights, hx, batch_sizes)
  File "/home/dangt7/le-traducteur/lib/python3.6/site-packages/torch/nn/_functions/rnn.py", line 323, in forward
    return func(input, *fargs, **fkwargs)
  File "/home/dangt7/le-traducteur/lib/python3.6/site-packages/torch/nn/_functions/rnn.py", line 287, in forward
    dropout_ts)
RuntimeError: rnn: hx is not contiguous

Here is my source code, specifically my forward() function. self.en_encoder is a Seq2VecEncoder and self.fr_encoder is a Seq2SeqEncoder:

        # Reverse the utterance before embedding.
        en_max_seq_len = en['tokens'].size()[-1]
        en_reversed_indices = torch.linspace(en_max_seq_len - 1, 0, en_max_seq_len).long()
        en_reversed_indices = en_reversed_indices.to(en['tokens'].device)  # CPU/GPU invariant.
        en_reversed_utterance = en['tokens'].index_select(-1, en_reversed_indices)
        assert(en['tokens'].equal(en_reversed_utterance.index_select(-1, en_reversed_indices)))
        en['tokens'] = en_reversed_utterance

        # Embed and encode the English utterance.
        # Results in a single vector representing the utterance.
        embedded_en_utterance = self.en_field_embedder(en)
        en_utterance_mask = util.get_text_field_mask(en)
        encoded_en_utterance = self.en_encoder(embedded_en_utterance, en_utterance_mask)

        # Prep the hidden state initialization of the word-level French LSTM.
        # Shape (no cell state): (num_layers, batch, en_hidden_size)
        # Shape (with cell state): Tuple of (num_layers, batch, en_hidden_size)'s
        fr_translation_primer = encoded_en_utterance.unsqueeze(0)
        fr_translation_primer = fr_translation_primer.expand(
            self._fr_encoder_num_layers,
            -1,  # Inferred from the other two.
            self._en_encoder_hidden_size
        ).contiguous()
        assert fr_translation_primer.is_contiguous()
        if self._fr_encoder_is_lstm:
            fr_translation_primer = (fr_translation_primer,
                                     torch.zeros_like(fr_translation_primer))

            assert fr_translation_primer[0].is_contiguous()
            assert fr_translation_primer[1].is_contiguous()

        # Embed and encode the French utterance.
        # Results in several vectors representing the utterance.
        # Shape: (batch, sequence_length, fr_hidden_size)
        embedded_fr_utterance = self.fr_field_embedder(fr)
        fr_utterance_mask = util.get_text_field_mask(fr)
        encoded_fr_utterance = self.fr_encoder(embedded_fr_utterance, fr_utterance_mask,
                                               hidden_state=fr_translation_primer)
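For context, a minimal sketch of why the asserts in this forward() pass. The sizes are hypothetical and a random tensor stands in for the Seq2VecEncoder output; the point is only that `expand()` alone returns a view with stride 0 on the new layer dimension, and the `.contiguous()` call copies it into a dense layout.

```python
import torch

# Hypothetical sizes for illustration only; not taken from the issue.
batch_size, num_layers, hidden_size = 4, 2, 8

# Stand-in for the Seq2VecEncoder output: (batch, en_hidden_size).
encoded_en_utterance = torch.randn(batch_size, hidden_size)

# expand() alone returns a view whose new leading dimension has stride 0,
# so it is not contiguous.
expanded = encoded_en_utterance.unsqueeze(0).expand(num_layers, -1, hidden_size)
print(expanded.is_contiguous())  # False

# .contiguous() materializes a copy with a standard row-major layout.
primer = expanded.contiguous()
print(primer.is_contiguous())  # True

# zeros_like allocates a fresh tensor for the LSTM cell state,
# with the same (contiguous) layout as primer.
cell = torch.zeros_like(primer)
print(cell.is_contiguous())  # True
```

So the hidden state handed to fr_encoder is already contiguous; whatever breaks contiguity has to happen later, inside the encoder.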
dangitstam commented 6 years ago

I got past the exception by adding a contiguous() call at https://github.com/allenai/allennlp/blob/01ddd126e339971eb752ec32ba024ba1734e3b71/allennlp/modules/encoder_base.py#L107

I changed the line to

initial_states = [state.index_select(1, sorting_indices)[:, :num_valid, :].contiguous()
                  for state in hidden_state]

since asserting contiguity in my model wasn't enough.

I don't think index_select() followed by the [:, :num_valid, :] slice guarantees a contiguous tensor, but the cuDNN RNN path used on the GPU requires hx to be contiguous, which explains the exception I was getting.
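A minimal sketch of the reordering-and-trimming step, using a hypothetical helper `sort_and_trim_states` (not an AllenNLP function) and made-up sizes. `index_select()` on its own allocates a new, contiguous tensor; the non-contiguous piece is the slice along the middle (batch) dimension. A plausible reason the error only appears partway through the epoch (an assumption, not confirmed in this thread) is that the slice only yields a strided view when a batch contains an all-padding sequence, so that num_valid is smaller than the batch size.

```python
import torch


def sort_and_trim_states(hidden_state, sorting_indices, num_valid, make_contiguous):
    """Hypothetical helper mirroring the line discussed above: reorder each
    state along the batch dimension (dim 1), then keep the first num_valid rows."""
    states = []
    for state in hidden_state:
        trimmed = state.index_select(1, sorting_indices)[:, :num_valid, :]
        states.append(trimmed.contiguous() if make_contiguous else trimmed)
    return states


# Hypothetical sizes: 2 layers, batch of 4, hidden size 8.
num_layers, batch_size, hidden_size = 2, 4, 8
h0 = torch.randn(num_layers, batch_size, hidden_size)
c0 = torch.zeros_like(h0)

# Assume one sequence in the batch is entirely padding, so only 3 are valid.
sorting_indices = torch.tensor([2, 0, 3, 1])
num_valid = 3

# index_select() itself returns a new, contiguous tensor...
print(h0.index_select(1, sorting_indices).is_contiguous())  # True

# ...but slicing the batch dimension leaves a strided view.
before = sort_and_trim_states((h0, c0), sorting_indices, num_valid, make_contiguous=False)
print([s.is_contiguous() for s in before])  # [False, False]

# Adding .contiguous(), as in the fix, copies each state into a dense layout.
after = sort_and_trim_states((h0, c0), sorting_indices, num_valid, make_contiguous=True)
print([s.is_contiguous() for s in after])  # [True, True]

# When every sequence is valid, the slice spans the whole batch and stays contiguous.
full = sort_and_trim_states((h0, c0), sorting_indices, batch_size, make_contiguous=False)
print([s.is_contiguous() for s in full])  # [True, True]
```

In this sketch, with num_layers = 1 the trimmed view would still report as contiguous, because PyTorch ignores the stride of size-1 dimensions when checking contiguity; the strided view only shows up with a leading dimension larger than 1.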

DeNeutoy commented 6 years ago

Hey, nice one figuring that out - it looked a bit complicated! Would you mind sending a PR?