huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Unexpected shape of "past_key_values" of ProphetNetDecoder.forward() #15362

Closed · meguruin closed this issue 2 years ago

meguruin commented 2 years ago

Environment info

Who can help

@patrickvonplaten, @LysandreJik

Information

Model I am using (Bert, XLNet ...): ProphetNet

The problem arises when using:

When I try to convert ProphetNetModel to ONNX, I find that the "past_key_values" in the decoder output do not have the shape described in the official documentation. The description of ProphetNetDecoderModelOutput says:

past_key_values (List[torch.FloatTensor], optional, returned when use_cache=True is passed or when config.use_cache=True) — List of torch.FloatTensor of length config.n_layers, with each tensor of shape (2, batch_size, num_attn_heads, decoder_sequence_length, embed_size_per_head)).

However, the past_key_values I actually get look like those of BaseModelOutputWithPastAndCrossAttentions:

past_key_values (tuple(tuple(torch.FloatTensor)), optional, returned when use_cache=True is passed or when config.use_cache=True) — Tuple of tuple(torch.FloatTensor) of length config.n_layers, with each tuple having 2 tensors of shape (batch_size, num_heads, sequence_length, embed_size_per_head)) and optionally if config.is_encoder_decoder=True 2 additional tensors of shape (batch_size, num_heads, encoder_sequence_length, embed_size_per_head).
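
To make the difference concrete, here is a minimal sketch of the two layouts. The dimensions are placeholders for illustration only (the encoder length of 8 is just an assumption; the real sizes come from the model config):

import torch

self_shape = (1, 16, 4, 64)   # (batch, heads, decoder_len, head_dim) for the self-attention cache
cross_shape = (1, 16, 8, 64)  # cross-attention cache uses the encoder sequence length

# layout described in the ProphetNetDecoderModelOutput docs:
# a list of config.n_layers tensors, with keys and values stacked along dim 0
documented = [torch.randn(2, *self_shape) for _ in range(12)]

# layout actually returned (as in BaseModelOutputWithPastAndCrossAttentions):
# a tuple of per-layer tuples holding 2 self-attention tensors plus,
# for encoder-decoder models, 2 cross-attention tensors
observed = tuple(
    (torch.randn(*self_shape), torch.randn(*self_shape),
     torch.randn(*cross_shape), torch.randn(*cross_shape))
    for _ in range(12)
)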

The tasks I am working on are:

To reproduce

from transformers import ProphetNetTokenizer, ProphetNetEncoder, ProphetNetDecoder

tokenizer = ProphetNetTokenizer.from_pretrained('microsoft/prophetnet-large-uncased')
encoder = ProphetNetEncoder.from_pretrained('microsoft/prophetnet-large-uncased')
decoder = ProphetNetDecoder.from_pretrained('microsoft/prophetnet-large-uncased')
# assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
enc_inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
dec_inputs = tokenizer("<s>", return_tensors="pt")
encoder_outputs = encoder(
    input_ids=enc_inputs["input_ids"],
    attention_mask=enc_inputs["attention_mask"],
    return_dict=True,
)
decoder_outputs = decoder(
    input_ids=dec_inputs["input_ids"],
    encoder_hidden_states=encoder_outputs["last_hidden_state"],
    encoder_attention_mask=enc_inputs["attention_mask"],
    past_key_values=None,
    return_dict=True,
)
print(decoder_outputs.keys())
# odict_keys(['last_hidden_state', 'last_hidden_state_ngram', 'past_key_values'])

print(len(decoder_outputs["past_key_values"])) 
# 12
print(len(decoder_outputs["past_key_values"][0]))
# 4
print(decoder_outputs["past_key_values"][0][0].shape)
# torch.Size([1, 16, 4, 64])

Expected behavior

len(decoder_outputs["past_key_values"]) == 12 and decoder_outputs["past_key_values"][0].shape == (2, batch_size, num_attn_heads, decoder_sequence_length, embed_size_per_head), as described in the documentation.
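
If a downstream consumer expects the documented layout, a minimal sketch (my own mapping, not an official helper) that restacks the self-attention entries of the actual output into that shape, continuing from the reproduction snippet above, would be:

import torch  # the reproduction snippet only imports from transformers

restacked = [
    torch.stack(layer_past[:2], dim=0)  # stack self-attention key and value along a new dim 0
    for layer_past in decoder_outputs["past_key_values"]
]
print(restacked[0].shape)
# torch.Size([2, 1, 16, 4, 64])

Note that the two cross-attention tensors per layer have no slot in the documented format, which is part of the mismatch.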

patrickvonplaten commented 2 years ago

For ONNX-related issues, gently pinging @lewtun here

meguruin commented 2 years ago

@patrickvonplaten, thanks. I think this issue is not only about ONNX usage, since it is a contradiction between the documentation and the actual outputs.

patrickvonplaten commented 2 years ago

Ok, I see, so the problem is the documentation here? There might very well be a bug in the documentation. The following looks correct to me:

print(len(decoder_outputs["past_key_values"])) 
# 12
print(len(decoder_outputs["past_key_values"][0]))
# 4
print(decoder_outputs["past_key_values"][0][0].shape)
# torch.Size([1, 16, 4, 64])

We have 12 layers. ProphetNet is an encoder-decoder model, so it caches 4 tensors per layer: the decoder key and value, as well as the projected encoder key and value matrices for the cross-attention layer.
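
For reference, a small loop over the output of the reproduction snippet that labels the four cached tensors per layer along these lines (the exact key/value ordering is my assumption, not taken from the docs):

labels = [
    "self-attn key (decoder)",
    "self-attn value (decoder)",
    "cross-attn key (projected encoder states)",
    "cross-attn value (projected encoder states)",
]
for layer_idx, layer_past in enumerate(decoder_outputs["past_key_values"]):
    for label, tensor in zip(labels, layer_past):
        print(layer_idx, label, tuple(tensor.shape))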

Would you maybe like to open a PR to fix the documentation, since it seems you've found the bug?

github-actions[bot] commented 2 years ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.