allenai / longformer

Longformer: The Long-Document Transformer
https://arxiv.org/abs/2004.05150
Apache License 2.0

Error when converting MBart to Longformer #206

Open edgartanaka opened 3 years ago

edgartanaka commented 3 years ago

Hi @ibeltagy!

I'm trying to convert MBart-50 to a Longformer version. I posted my code here: https://gist.github.com/edgartanaka/0d69b50e39f96cb0738f9808d48158a2

The base of this code was your script https://github.com/allenai/longformer/blob/caefee668e39cacdece7dd603a0bebf24df6d8ca/scripts/convert_bart_to_longformerencoderdecoder.py (BTW, thanks for sharing!). Your code ran successfully with an older version of transformers (v3.1.0), but when I upgraded to transformers v4.4.0 I started hitting issues. I was able to move past some errors, but now I'm stuck on this one. Any idea what the problem might be?
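For context, the conversion boils down to swapping each encoder self-attention module for the repo's Longformer wrapper and copying the pretrained projection weights over. Roughly (a sketch only; the gist builds a dedicated LongformerEncoderDecoderConfig rather than patching config fields like this, and the MBart attribute paths are my adaptation of the BART script):

```python
import copy
from transformers import MBartForConditionalGeneration
from longformer.longformer_encoder_decoder import LongformerSelfAttentionForBart

model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50")
config = model.config
# the real script builds a LongformerEncoderDecoderConfig carrying
# attention_window/attention_mode; patched in here only for brevity
config.attention_window = [512] * config.encoder_layers

for i, layer in enumerate(model.model.encoder.layers):
    longformer_self_attn = LongformerSelfAttentionForBart(config, layer_id=i)

    # reuse the pretrained q/k/v projections for the sliding-window attention
    longformer_self_attn.longformer_self_attn.query = layer.self_attn.q_proj
    longformer_self_attn.longformer_self_attn.key = layer.self_attn.k_proj
    longformer_self_attn.longformer_self_attn.value = layer.self_attn.v_proj

    # global-attention projections start as copies of the same weights
    longformer_self_attn.longformer_self_attn.query_global = copy.deepcopy(layer.self_attn.q_proj)
    longformer_self_attn.longformer_self_attn.key_global = copy.deepcopy(layer.self_attn.k_proj)
    longformer_self_attn.longformer_self_attn.value_global = copy.deepcopy(layer.self_attn.v_proj)

    longformer_self_attn.output = layer.self_attn.out_proj
    layer.self_attn = longformer_self_attn
```

Here's the traceback I hit when calling generate: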

Traceback (most recent call last):
  File "/Users/edgart/git/podcast-summarization-edgart/experiments/long4/convert_bart_to_longformerencoderdecoder.py", line 183, in <module>
    main()
  File "/Users/edgart/git/podcast-summarization-edgart/experiments/long4/convert_bart_to_longformerencoderdecoder.py", line 179, in main
    summary_test(args)
  File "/Users/edgart/git/podcast-summarization-edgart/experiments/long4/convert_bart_to_longformerencoderdecoder.py", line 130, in summary_test
    summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=50, early_stopping=True)
  File "/Users/edgart/.pyenv/versions/long4/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/Users/edgart/.pyenv/versions/long4/lib/python3.9/site-packages/transformers/generation_utils.py", line 916, in generate
    model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
  File "/Users/edgart/.pyenv/versions/long4/lib/python3.9/site-packages/transformers/generation_utils.py", line 411, in _prepare_encoder_decoder_kwargs_for_generation
    model_kwargs["encoder_outputs"]: ModelOutput = encoder(input_ids, return_dict=True, **encoder_kwargs)
  File "/Users/edgart/.pyenv/versions/long4/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/edgart/.pyenv/versions/long4/lib/python3.9/site-packages/transformers/models/mbart/modeling_mbart.py", line 799, in forward
    layer_outputs = encoder_layer(
  File "/Users/edgart/.pyenv/versions/long4/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/edgart/.pyenv/versions/long4/lib/python3.9/site-packages/transformers/models/mbart/modeling_mbart.py", line 313, in forward
    hidden_states, attn_weights, _ = self.self_attn(
  File "/Users/edgart/.pyenv/versions/long4/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/edgart/git/podcast-summarization-edgart/experiments/long4/longformer_encoder_decoder.py", line 60, in forward
    hidden_states=hidden_states,  # I'm guessing I just need to pass
  File "/Users/edgart/.pyenv/versions/long4/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/edgart/.pyenv/versions/long4/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py", line 603, in forward
    diagonal_mask = self._sliding_chunks_query_key_matmul(
  File "/Users/edgart/.pyenv/versions/long4/lib/python3.9/site-packages/transformers/models/longformer/modeling_longformer.py", line 799, in _sliding_chunks_query_key_matmul
    batch_size, seq_len, num_heads, head_dim = query.size()
ValueError: too many values to unpack (expected 4)

Thanks! Edgar

helmosor commented 3 years ago

Hi Edgar

Adding attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1) before feeding attention_mask to the forward method seemed to solve the issue on my side!
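In context, that line would sit at the top of the custom wrapper's forward, before the mask reaches LongformerSelfAttention. A sketch, assuming a wrapper that delegates via self.longformer_self_attn as in the conversion script, and assuming the mask arrives in the broadcastable [batch_size, 1, 1, seq_len] form (squeeze only removes size-1 dimensions, so it is a no-op on a full [batch_size, 1, seq_len, seq_len] mask):

```python
def forward(self, hidden_states, attention_mask=None, output_attentions=False, **kwargs):
    # transformers v4.x hands the encoder layers a 4-D additive mask,
    # while LongformerSelfAttention expects a 2-D [batch_size, seq_len] mask
    if attention_mask is not None and attention_mask.dim() == 4:
        attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1)
    return self.longformer_self_attn(
        hidden_states,
        attention_mask=attention_mask,
        output_attentions=output_attentions,
    )
```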

SCNUJackyChen commented 2 years ago

I think the reason is that the Hugging Face transformers source code has been updated: the main difference is that they removed key_padding_mask (shape [batch_size, key_len]) and replaced it with a 4-D attention_mask (shape [batch_size, 1, tgt_len, src_len]). So the way to make the two compatible is to modify the Hugging Face source code.
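If that diagnosis is right, one alternative to patching the installed package is to undo the mask expansion inside the attention wrapper itself, which avoids the size-1 assumption of the squeeze workaround above. A minimal sketch (the additive convention, 0 for keep and a large negative value for masked positions, is an assumption about what the encoder passes in):

```python
import torch

def to_2d_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    """Collapse a 4-D additive mask [batch_size, 1, tgt_len, src_len]
    back to the 2-D [batch_size, src_len] form LongformerSelfAttention
    expects (0 = local attention, large negative = masked padding)."""
    # padding is identical for every query position, so one row suffices
    return attention_mask[:, 0, 0, :]
```

Calling to_2d_mask in the wrapper's forward whenever attention_mask.dim() == 4 keeps the custom attention compatible with both the v3 and v4 calling conventions.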

I gave this a try using the facebook/mbart-large-50-one-to-many checkpoint; you can check my code here