facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
MIT License

Multi-head attention module throwing RuntimeError: view size is not compatible with input tensor's size and stride #2598

Open nate-bush opened 4 years ago

nate-bush commented 4 years ago

🐛 Bug

Multi-head attention module throwing error: "RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead."
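The failure mode can be reproduced outside fairseq with a minimal sketch (not fairseq code): calling `.view()` on a non-contiguous tensor, where the requested shape would merge dimensions that span two contiguous subspaces of memory, raises exactly this error, while `.reshape()` succeeds by copying.

```python
import torch

# Minimal sketch (not fairseq code): .view() never copies, so it fails
# when the requested shape is incompatible with the tensor's strides.
x = torch.randn(2, 3, 4)
y = x.transpose(0, 1)  # shape (3, 2, 4), but now non-contiguous

try:
    y.view(6, 4)  # merging the first two dims spans two contiguous subspaces
except RuntimeError as e:
    print(e)  # "view size is not compatible with input tensor's size and stride ..."

z = y.reshape(6, 4)  # reshape() falls back to copying, so it succeeds
print(z.shape)
```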

To Reproduce

  1. Train a model using a TPU machine

    python train.py \
        my/redacted/data/dir/ \
        --arch transformer_iwslt_de_en \
        --task translation \
        --source-lang zh \
        --target-lang en \
        --lr 0.0005 \
        --weight-decay 0.0001 \
        --optimizer adam \
        --adam-betas '(0.9, 0.98)' \
        --lr-scheduler inverse_sqrt \
        --min-lr 1e-09 \
        --warmup-updates 4000 \
        --warmup-init-lr 1e-07 \
        --label-smoothing 0.1 \
        --clip-norm 5.0 \
        --criterion label_smoothed_cross_entropy \
        --dropout 0.1 \
        --max-tokens 48000 \
        --save-dir data/mt/output/checkpoints \
        --validate-interval 1 \
        --tensorboard-logdir output/tensorboard_logs \
        --no-progress-bar \
        --log-interval 100 \
        --max-source-positions 1024 \
        --max-target-positions 1024 \
        --skip-invalid-size-inputs-valid-test \
        --no-epoch-checkpoints \
        --max-epoch 20 \
        --distributed-world-size 8 \
        --tpu \
        --num-batch-buckets 5

  2. Generate results (reproduced on TPU, GPU, and CPU)

    python /anaconda3/envs/torch-xla-1.6/lib/python3.6/site-packages/fairseq_cli/generate.py \
        my/redacted/data/dir/ \
        --task translation \
        --path my/redacted/model/dir/checkpoints/checkpoint_best.pt \
        --beam 5 \
        --remove-bpe \
        --skip-invalid-size-inputs-valid-test \
        --scorer sacrebleu \
        --source-lang zh \
        --target-lang en

  3. See error

On GPU...

2020-09-09 20:42:57 | WARNING | fairseq.data.data_utils | 62 samples have invalid sizes and will be skipped, max_positions=(1024, 1024), first few sample ids=[107311, 69863, 57884, 54207, 83884, 71724, 46, 47217, 64581, 106659]
Traceback (most recent call last):
  File "/opt/conda/bin/fairseq-generate", line 8, in <module>
    sys.exit(cli_main())
  File "/opt/conda/lib/python3.7/site-packages/fairseq_cli/generate.py", line 274, in cli_main
    main(args)
  File "/opt/conda/lib/python3.7/site-packages/fairseq_cli/generate.py", line 38, in main
    return _main(args, sys.stdout)
  File "/opt/conda/lib/python3.7/site-packages/fairseq_cli/generate.py", line 150, in _main
    hypos = task.inference_step(generator, models, sample, prefix_tokens)
  File "/opt/conda/lib/python3.7/site-packages/fairseq/tasks/fairseq_task.py", line 361, in inference_step
    return generator.generate(models, sample, prefix_tokens=prefix_tokens)
  File "/opt/conda/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 15, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/fairseq/sequence_generator.py", line 159, in generate
    return self._generate(sample, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/fairseq/sequence_generator.py", line 198, in _generate
    encoder_outs = self.model.forward_encoder(net_input)
  File "/opt/conda/lib/python3.7/site-packages/fairseq/sequence_generator.py", line 697, in forward_encoder
    for model in self.models
  File "/opt/conda/lib/python3.7/site-packages/fairseq/sequence_generator.py", line 697, in <listcomp>
    for model in self.models
  File "/opt/conda/lib/python3.7/site-packages/fairseq/models/fairseq_encoder.py", line 53, in forward_torchscript
    return self.forward_non_torchscript(net_input)
  File "/opt/conda/lib/python3.7/site-packages/fairseq/models/fairseq_encoder.py", line 62, in forward_non_torchscript
    return self.forward(**encoder_input)
  File "/opt/conda/lib/python3.7/site-packages/fairseq/models/transformer.py", line 411, in forward
    x = layer(x, encoder_padding_mask)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/fairseq/modules/transformer_layer.py", line 122, in forward
    attn_mask=attn_mask,
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/fairseq/modules/multihead_attention.py", line 342, in forward
    attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

Reproduced on TPU...

2020-09-09 20:38:46 | INFO | fairseq_cli.generate | loading model(s) from redacted/path/checkpoints/checkpoint_best.pt
2020-09-09 20:38:52 | WARNING | fairseq.data.data_utils | 62 samples have invalid sizes and will be skipped, max_positions=(1024, 1024), first few sample ids=[107311, 69863, 57884, 54207, 83884, 71724, 46, 47217, 64581, 106659]
Traceback (most recent call last):
  File "/anaconda3/envs/torch-xla-1.6/lib/python3.6/site-packages/fairseq_cli/generate.py", line 278, in <module>
    cli_main()
  File "/anaconda3/envs/torch-xla-1.6/lib/python3.6/site-packages/fairseq_cli/generate.py", line 274, in cli_main
    main(args)
  File "/anaconda3/envs/torch-xla-1.6/lib/python3.6/site-packages/fairseq_cli/generate.py", line 38, in main
    return _main(args, sys.stdout)
  File "/anaconda3/envs/torch-xla-1.6/lib/python3.6/site-packages/fairseq_cli/generate.py", line 150, in _main
    hypos = task.inference_step(generator, models, sample, prefix_tokens)
  File "/anaconda3/envs/torch-xla-1.6/lib/python3.6/site-packages/fairseq/tasks/fairseq_task.py", line 361, in inference_step
    return generator.generate(models, sample, prefix_tokens=prefix_tokens)
  File "/anaconda3/envs/torch-xla-1.6/lib/python3.6/site-packages/torch/autograd/grad_mode.py", line 15, in decorate_context
    return func(*args, **kwargs)
  File "/anaconda3/envs/torch-xla-1.6/lib/python3.6/site-packages/fairseq/sequence_generator.py", line 159, in generate
    return self._generate(sample, **kwargs)
  File "/anaconda3/envs/torch-xla-1.6/lib/python3.6/site-packages/fairseq/sequence_generator.py", line 198, in _generate
    encoder_outs = self.model.forward_encoder(net_input)
  File "/anaconda3/envs/torch-xla-1.6/lib/python3.6/site-packages/fairseq/sequence_generator.py", line 697, in forward_encoder
    for model in self.models
  File "/anaconda3/envs/torch-xla-1.6/lib/python3.6/site-packages/fairseq/sequence_generator.py", line 697, in <listcomp>
    for model in self.models
  File "/anaconda3/envs/torch-xla-1.6/lib/python3.6/site-packages/fairseq/models/fairseq_encoder.py", line 53, in forward_torchscript
    return self.forward_non_torchscript(net_input)
  File "/anaconda3/envs/torch-xla-1.6/lib/python3.6/site-packages/fairseq/models/fairseq_encoder.py", line 62, in forward_non_torchscript
    return self.forward(**encoder_input)
  File "/anaconda3/envs/torch-xla-1.6/lib/python3.6/site-packages/fairseq/models/transformer.py", line 411, in forward
    x = layer(x, encoder_padding_mask)
  File "/anaconda3/envs/torch-xla-1.6/lib/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/anaconda3/envs/torch-xla-1.6/lib/python3.6/site-packages/fairseq/modules/transformer_layer.py", line 122, in forward
    attn_mask=attn_mask,
  File "/anaconda3/envs/torch-xla-1.6/lib/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/anaconda3/envs/torch-xla-1.6/lib/python3.6/site-packages/fairseq/modules/multihead_attention.py", line 342, in forward
    attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

NOTE: the TPU vs. GPU machine is probably a red herring; the real difference is more likely the fairseq or PyTorch versions installed in each environment.

Code sample

See fairseq CLI commands above.

Expected behavior

No error; fairseq-generate produces its translation output.

Environment

Environment for running fairseq-train:

Environment for running fairseq-generate:

Additional context

The error message is pretty clear, so I replaced line 342 of the multi-head attention module with

attn_weights = attn_weights.reshape(bsz * self.num_heads, tgt_len, src_len)

and the issue was resolved. This is a potential fix, but I understand that reshape() can be more expensive than view() (it may copy when the tensor is non-contiguous), so I'll leave it to the experts to decide.
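For comparison (a sketch of the trade-off, not a proposed patch): the other common fix in this situation is `.contiguous().view(...)`, which makes the copy explicit; `.reshape()` performs the same copy only when one is actually needed and is otherwise a zero-cost view.

```python
import torch

y = torch.randn(2, 3, 4).transpose(0, 1)  # non-contiguous (3, 2, 4)

a = y.reshape(6, 4)            # copies only because y is non-contiguous here
b = y.contiguous().view(6, 4)  # explicit copy, then a free view
assert torch.equal(a, b)

# On an already-contiguous tensor, reshape() returns a view and copies
# nothing, so the extra cost only applies when a copy is unavoidable.
c = torch.randn(6, 4)
assert c.reshape(24).data_ptr() == c.data_ptr()  # same underlying storage
```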

lematt1991 commented 4 years ago

CC @myleott