idiap / fast-transformers

Pytorch library for fast transformer implementations

Expected usage of `length_masks` in `TransformerEncoder.forward` #46

Closed xvr-hlt closed 4 years ago

xvr-hlt commented 4 years ago

Hi, I'm unsure how to use length_mask with either softmax/full attention or linear attention TransformerEncoder models. If this parameter is not supported for these models, it would be great to get an informative error message. Usage:

import torch
from fast_transformers.builders import TransformerEncoderBuilder
from fast_transformers.masking import LengthMask

# Create the builder for our transformers
builder = TransformerEncoderBuilder.from_kwargs(n_layers=8,
                                                n_heads=8,
                                                query_dimensions=64,
                                                value_dimensions=64,
                                                feed_forward_dimensions=1024)

# Build a transformer with softmax attention
builder.attention_type = "full"
softmax_model = builder.get()

# Build a transformer with linear attention
builder.attention_type = "linear"
linear_model = builder.get()

# Construct the dummy input
X = torch.rand(10, 128, 8 * 64)

# Construct the length array corresponding to all elements being length 64
lengths = torch.Tensor([64] * 10).long()  # tensor([64, 64, 64, 64, 64, 64, 64, 64, 64, 64])
length_mask = LengthMask(lengths)

y = softmax_model(X, length_mask=length_mask)

Results in the error:

.../fast_transformers/attention/full_attention.py in forward(self, queries, keys, values, attn_mask, query_lengths, key_lengths)
     66         if not attn_mask.all_ones:
     67             QK = QK + attn_mask.additive_matrix
---> 68         QK = QK + key_lengths.additive_matrix[:, None, None]
     69 
     70         # Compute the attention and the weighted average

RuntimeError: The size of tensor a (128) must match the size of tensor b (64) at non-singleton dimension 3

And ditto with y = linear_model(X, length_mask=length_mask).

angeloskath commented 4 years ago

Hi,

I agree that we should find a way to improve the error message here. What is happening is that the mask you create with LengthMask has a maximum allowed size of 64: since you do not set max_len explicitly, it is assumed to be max(lengths).
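
As a quick illustration (using the additive_matrix property that appears in the traceback above):

import torch
from fast_transformers.masking import LengthMask

lengths = torch.Tensor([64] * 10).long()

# Without max_len the mask is only lengths.max() == 64 positions wide,
# which cannot be broadcast against the 128-step input above.
print(LengthMask(lengths).additive_matrix.shape)               # torch.Size([10, 64])

# Setting max_len explicitly widens the mask to the padded sequence length.
print(LengthMask(lengths, max_len=128).additive_matrix.shape)  # torch.Size([10, 128])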

You can check our docs for masking.

So that means that you can fix the code above simply with

length_mask = LengthMask(lengths, max_len=128)
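
With that change, the two calls from your snippet should run and return outputs of the same shape as the input:

y = softmax_model(X, length_mask=length_mask)  # torch.Size([10, 128, 512])
y = linear_model(X, length_mask=length_mask)   # same shape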

You can also think of it as follows: when you pad, you wouldn't pad beyond what is necessary, right? So usually you pad up to lengths.max() and not 2*lengths.max(), as is the case here. However, the error message is definitely not helpful.

Cheers, Angelos

xvr-hlt commented 4 years ago

@angeloskath cheers and thanks for pointing me in the right direction :)

FWIW, from my perspective an ideal interface for length masking wouldn't require specifying the max length. I feel like padding to a fixed length (which means the longest non-padded sequence is batch-dependent and not lengths.max()) is at least somewhat common and a little annoying in the current formulation. Definitely still workable though!
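
In the meantime, one workaround under the current interface is to derive max_len from the padded batch itself; a small sketch (the make_length_mask helper is just illustrative, not part of the library):

from fast_transformers.masking import LengthMask

def make_length_mask(lengths, padded_x):
    # Tie the mask width to whatever the batch was actually padded to,
    # so fixed-length padding beyond lengths.max() still broadcasts.
    return LengthMask(lengths, max_len=padded_x.shape[1])

y = softmax_model(X, length_mask=make_length_mask(lengths, X))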