I think this can be solved by building a distance matrix with shape (target sequence length, source sequence length) for the relative position embedding, although I'm not sure whether this approach is theoretically sound for cross-attention (the original paper focuses on self-attention):
```python
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
    # Query positions come from the target side; key positions come from the
    # encoder (source) side when this is a cross-attention layer.
    seq_length_l = hidden_states.size()[1]
    seq_length_r = encoder_hidden_states.size()[1] if is_cross_attention else hidden_states.size()[1]
    position_ids_l = torch.arange(seq_length_l, dtype=torch.long, device=hidden_states.device).view(-1, 1)
    position_ids_r = torch.arange(seq_length_r, dtype=torch.long, device=hidden_states.device).view(1, -1)
    # Distance matrix of shape (target sequence length, source sequence length),
    # matching the last two dimensions of attention_scores.
    distance = position_ids_l - position_ids_r
    positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
    positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

    if self.position_embedding_type == "relative_key":
        relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
        attention_scores = attention_scores + relative_position_scores
    elif self.position_embedding_type == "relative_key_query":
        relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
        relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
        attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
```
cc @patrickvonplaten
Sorry to reply only now!
Just attached a PR that should fix the problem. IMO, cross-attention layers should never make use of positional encodings, as they don't really make sense there. E.g. T5 also uses relative position encodings and simply disables them for the cross-attention layers: https://github.com/huggingface/transformers/blob/a3ded170e22b37027dab456a12ff2f523c99d998/src/transformers/models/t5/modeling_t5.py#L563
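For reference, a minimal sketch of what such a guard could look like in BERT's attention code; this is illustrative only and not necessarily what the attached PR does (it assumes the `is_cross_attention` flag is in scope at that point):

```python
# Hypothetical guard: compute relative position scores only for self-attention,
# analogous to T5 building its cross-attention layers without a relative
# attention bias. Not necessarily what the attached PR implements.
use_relative = self.position_embedding_type in ("relative_key", "relative_key_query")
if use_relative and not is_cross_attention:
    seq_length = hidden_states.size()[1]
    position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
    position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
    distance = position_ids_l - position_ids_r
    # ... the rest of the existing relative position logic stays unchanged ...
```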
Let me know what you guys think @qqaatw and @dvruette !
Environment info
`transformers` version: 4.5.1
Who can help
@LysandreJik @patrickvonplaten
Information
Model I am using (Bert, XLNet ...): BertModel, EncoderDecoderModel
The problem arises when using:
The tasks I am working on are:
To reproduce
Steps to reproduce the behavior:
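The original reproduction snippet is not included in this excerpt; the following is a minimal sketch of the kind of setup that triggers the error (the configuration values and sequence lengths are illustrative):

```python
import torch
from transformers import BertConfig, BertModel, EncoderDecoderModel

# Relative position embeddings are what trigger the cross-attention shape mismatch.
encoder = BertModel(BertConfig(position_embedding_type="relative_key_query"))
decoder = BertModel(
    BertConfig(position_embedding_type="relative_key_query", is_decoder=True, add_cross_attention=True)
)
model = EncoderDecoderModel(encoder=encoder, decoder=decoder)

# Source and target sequences of different lengths (src_len != tgt_len).
input_ids = torch.randint(0, encoder.config.vocab_size, (1, 16))
decoder_input_ids = torch.randint(0, decoder.config.vocab_size, (1, 8))

# Raises a shape error in the decoder's cross-attention relative position computation.
outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
```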
Expected behavior
The above code snippet is expected to run without errors.
Instead, it produces error [1] exactly when `src_len != tgt_len`. This breaks any setup where the source sequences may have a different length than the target sequences, which includes my setup. The same error occurs for the `relative_key` position embedding type.
The problem can be circumvented by padding the sequences to the same length, but this is not a good solution with respect to performance, e.g. if the source sequence is much longer than the target sequence.
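For illustration only, the padding workaround mentioned above could look roughly like the following sketch (the tensor names are taken from the hypothetical reproduction sketch above, and the pad token id is assumed to be BERT's default of 0):

```python
import torch.nn.functional as F

# Hypothetical workaround: right-pad the shorter decoder sequence so that
# tgt_len == src_len. Wasteful when the source is much longer than the target.
pad_len = input_ids.size(1) - decoder_input_ids.size(1)
if pad_len > 0:
    decoder_input_ids = F.pad(decoder_input_ids, (0, pad_len), value=0)  # 0 = assumed pad token id
```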
Error [1]: