idiap / fast-transformers

Pytorch library for fast transformer implementations

TypeError: forward() missing 2 required positional arguments: 'query_lengths' and 'key_lengths' #67

Closed mHsuann closed 3 years ago

mHsuann commented 3 years ago

Hi, I tried to use 'clustered attention' in my Transformer model, whose code follows 'The Annotated Transformer.' I ran into a problem while adapting it: the self-attention module in the Annotated Transformer is:

class MultiHeadAttention(nn.Module):
    "Standard multi-head attention from The Annotated Transformer."

    def __init__(self, h, d_model, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        # Four projections: query, key, value, and the final output projection.
        # `clones` is the helper from The Annotated Transformer.
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # Project the inputs and reshape to (batch, heads, seq_len, d_k).
        query, key, value = [
            l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            for l, x in zip(self.linears, (query, key, value))
        ]

        # `attention` is the scaled dot-product helper from The Annotated Transformer.
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)

However, the clustered attention module is:

class ClusteredAttention(Module):
    ...
    def forward(self, queries, keys, values, attn_mask, query_lengths, key_lengths):
       ...

So, I got the TypeError.

TypeError: forward() missing 2 required positional arguments: 'query_lengths' and 'key_lengths'

I would like to ask: where does query_lengths come from? I only know that it is obtained through lengths(self). (Sorry for my limited English.)

Thank You !!!

angeloskath commented 3 years ago

Hi,

You can find a pretty thorough explanation in our docs: https://fast-transformers.github.io/attention/.

TL;DR: query_lengths is a mask that encodes the number of queries in each sequence of the batch, key_lengths does the same for the keys, and attn_mask defines which keys each query can attend to. For a transformer encoder, query_lengths and key_lengths should be the same; a minimal sketch follows below.
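
For reference, here is a minimal sketch of how those arguments can be built with fast_transformers.masking and passed to the attention module (the shapes, the clusters value, and the full-length batch are illustrative assumptions, not something taken from this thread):

import torch
from fast_transformers.masking import FullMask, LengthMask
from fast_transformers.attention import ClusteredAttention

N, L, H, E = 2, 512, 8, 64           # batch, sequence length, heads, head dim
queries = torch.randn(N, L, H, E)    # fast-transformers attention modules expect
keys = torch.randn(N, L, H, E)       # (batch, length, heads, head_dim) inputs
values = torch.randn(N, L, H, E)

# attn_mask: every query may attend to every key (no causal restriction)
attn_mask = FullMask(L)

# query_lengths / key_lengths: one length per sequence in the batch; here every
# sequence is full length, with padding you would pass the true lengths instead
lengths = torch.full((N,), L, dtype=torch.int64)
query_lengths = LengthMask(lengths, max_len=L)
key_lengths = query_lengths          # identical for an encoder

attention = ClusteredAttention(clusters=100)   # 100 clusters is a placeholder value
# Clustered attention may need a CUDA device depending on how the package was built.
out = attention(queries, keys, values, attn_mask, query_lengths, key_lengths)
print(out.shape)                     # torch.Size([2, 512, 8, 64])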

I am closing the issue, but feel free to reopen it if you have more questions.

Cheers, Angelos