huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Bert: relative_key position embedding causes error in encoder-decoder setup #14010

Closed. dvruette closed this issue 2 years ago.

dvruette commented 3 years ago

Environment info

Who can help

@LysandreJik @patrickvonplaten

Information

Model I am using (Bert, XLNet ...): BertModel, EncoderDecoderModel

The problem arises when using:

The task I am working on is:

To reproduce

Steps to reproduce the behavior:

  1. Copy/paste the code snippet below
  2. Run the script

import torch
from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel

config = {
  'hidden_size': 512,
  'num_attention_heads': 8,
  'position_embedding_type': 'relative_key_query'  # the same error occurs with 'relative_key'
}
encoder_config = BertConfig(**config)
decoder_config = BertConfig(**config)
config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)

model = EncoderDecoderModel(config)
model.config.decoder.is_decoder = True
model.config.decoder.add_cross_attention = True

# a source length different from the target length triggers the error
batch_size, src_len, tgt_len = 1, 2, 3
x = torch.zeros(batch_size, src_len).int()
y = torch.zeros(batch_size, tgt_len).int()

model(input_ids=x, decoder_input_ids=y)

Expected behavior

The above code snippet is expected to run without errors.

Instead, it produces error [1] unless src_len == tgt_len. This breaks any setup where source sequences may have a different length than the target sequences, which includes mine. The same error occurs for the relative_key position embedding type.

The problem can be circumvented by padding source and target sequences to the same length, but this is not a good solution performance-wise, e.g. when the source sequence is much longer than the target sequence.
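
For illustration, here is a minimal sketch of that padding workaround on top of the repro snippet above (pad id 0 and the hand-built attention masks are assumptions; a real setup would use the tokenizer's pad_token_id):

import torch.nn.functional as F

# pad the shorter of the two sequences so that src_len == tgt_len
max_len = max(src_len, tgt_len)
x_pad = F.pad(x, (0, max_len - src_len), value=0)  # assumed pad id 0
y_pad = F.pad(y, (0, max_len - tgt_len), value=0)
x_mask = F.pad(torch.ones_like(x), (0, max_len - src_len), value=0)
y_mask = F.pad(torch.ones_like(y), (0, max_len - tgt_len), value=0)

model(input_ids=x_pad, attention_mask=x_mask,
      decoder_input_ids=y_pad, decoder_attention_mask=y_mask)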

Error [1]:

[...]
File "~/opt/miniconda3/envs/symbolic-music/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py", line 306, in forward
    relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
  File "~/opt/miniconda3/envs/symbolic-music/lib/python3.8/site-packages/torch/functional.py", line 299, in einsum
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
RuntimeError: einsum(): operands do not broadcast with remapped shapes [original->remapped]: [1, 8, 2, 64]->[1, 8, 1, 2, 64] [3, 3, 64]->[1, 1, 3, 3, 64]
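
For reference, the mismatch can be reproduced in isolation: the relative-position embedding is built from the decoder sequence length (3), while the key layer in cross-attention comes from the encoder sequence (length 2), so the einsum cannot align the r axis. A standalone sketch with the shapes from the traceback (variable names are illustrative):

import torch

key_layer = torch.zeros(1, 8, 2, 64)          # (batch, heads, encoder/src length, head dim)
positional_embedding = torch.zeros(3, 3, 64)  # (decoder/tgt length, decoder/tgt length, head dim)
torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)  # RuntimeError: operands do not broadcast
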
qqaatw commented 3 years ago

I think this can be solved by building the distance matrix with shape (target sequence length, source sequence length) for the relative position embedding, although I'm not sure whether this approach is theoretically sound for cross-attention (the original paper only considers self-attention).

if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
    # query positions come from the decoder; key positions come from the encoder when cross-attending
    seq_length_l = hidden_states.size()[1]
    seq_length_r = encoder_hidden_states.size()[1] if is_cross_attention else hidden_states.size()[1]
    position_ids_l = torch.arange(seq_length_l, dtype=torch.long, device=hidden_states.device).view(-1, 1)
    position_ids_r = torch.arange(seq_length_r, dtype=torch.long, device=hidden_states.device).view(1, -1)
    distance = position_ids_l - position_ids_r
    positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
    positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

    if self.position_embedding_type == "relative_key":
        relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
        attention_scores = attention_scores + relative_position_scores
    elif self.position_embedding_type == "relative_key_query":
        relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
        relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
        attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

cc @patrickvonplaten

patrickvonplaten commented 3 years ago

Sorry for replying only now!

Just attached a PR that should fix the problem. IMO, cross-attention layers should never make use of positional encodings, as they don't really make sense there. E.g. T5 also uses relative position encodings and simply disables them for the cross-attention layers: https://github.com/huggingface/transformers/blob/a3ded170e22b37027dab456a12ff2f523c99d998/src/transformers/models/t5/modeling_t5.py#L563
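
For context, the T5 precedent can be checked directly; a small sketch, assuming the attribute layout of modeling_t5.py (only the first self-attention layer carries the relative attention bias, cross-attention never does):

from transformers import T5Config, T5ForConditionalGeneration

model = T5ForConditionalGeneration(T5Config())
block = model.decoder.block[0]
print(block.layer[0].SelfAttention.has_relative_attention_bias)   # True: self-attention of the first block
print(block.layer[1].EncDecAttention.has_relative_attention_bias) # False: no relative positions in cross-attention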

patrickvonplaten commented 3 years ago

Let me know what you think, @qqaatw and @dvruette!