lucidrains / performer-pytorch

An implementation of Performer, a linear attention-based transformer, in Pytorch
MIT License
1.08k stars, 141 forks

Performance gain replacing original attention to fast attention in this repo? #52

Open phypan11 opened 3 years ago

phypan11 commented 3 years ago

I do see a clear memory saving when the sequence is long and the feature dimension is small.

But I do not see any speed-up, and under the same conditions the loss drops more slowly.

This is probably because my dataset and task are not a good fit for the Performer (sequence not long enough, task not hard enough, etc.),

but I would like to know whether anyone has seen a tangible performance gain from switching to the Performer's fast attention.
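To make the "sequence not long enough" guess concrete, here is a rough back-of-the-envelope cost model. It is only a sketch: it counts multiply-adds per head and ignores the exponentials, normalization, and how well each kernel actually runs on a GPU. The function names, the head dimension d = 64, and the number of random features m = 256 are my own assumptions for illustration, not values taken from this repo.

```python
def softmax_attention_cost(n, d):
    """Rough multiply-add count and extra memory (in elements) for one head."""
    macs = 2 * n * n * d   # Q @ K^T, then attention weights @ V
    memory = n * n         # the n x n score matrix
    return macs, memory


def linear_attention_cost(n, d, m):
    """Same estimate for kernelized (Performer-style) attention with m features."""
    macs = 2 * n * d * m   # project Q and K onto m features
    macs += 2 * n * m * d  # K'^T @ V, then Q' @ (K'^T @ V)
    memory = 2 * n * m + m * d  # projected Q and K, plus the m x d summary
    return macs, memory


def crossover_length(m):
    """Sequence length where both multiply-add counts match: 2*n^2*d = 4*n*m*d."""
    return 2 * m


if __name__ == "__main__":
    d, m = 64, 256  # assumed head dimension and feature count
    for n in (256, 512, 2048, 8192):
        soft_macs, soft_mem = softmax_attention_cost(n, d)
        lin_macs, lin_mem = linear_attention_cost(n, d, m)
        print(f"n={n}: compute ratio {soft_macs / lin_macs:.2f}, "
              f"memory ratio {soft_mem / lin_mem:.2f}")
    print("break-even length for compute:", crossover_length(m))
```

Under these assumptions the linear variant only starts to win on raw operation count once the sequence is longer than about twice the number of features, so with short sequences a slowdown would not be surprising. Real timings can differ a lot from this count, so it is worth measuring on the actual model.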

For your information, I merely replaced this simple attention mechanism:

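import math

import torch
import torch.nn.functional as F

def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    # (..., n, d_k) @ (..., d_k, n) -> (..., n, n) scores, quadratic in sequence length
    scores = torch.matmul(query, key.transpose(-2, -1)) \
             / math.sqrt(d_k)
    if mask is not None:
        # Masked positions get a large negative score, so softmax gives them ~0 weight
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn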

with FastAttention from this repo.
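For anyone reading along, the memory saving comes from reordering the matrix products so the n x n matrix is never built. The sketch below shows that reordering with a simple positive feature map (elu(x) + 1). That feature map is a stand-in I picked for illustration; it is not the random-feature approximation of softmax that FastAttention uses, and the function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def feature_map(x):
    # Simple positive feature map (assumption for illustration only);
    # the Performer uses random features that approximate softmax instead.
    return F.elu(x) + 1


def quadratic_kernel_attention(q, k, v):
    # Builds the full (n, n) weight matrix, like ordinary attention.
    qf, kf = feature_map(q), feature_map(k)
    scores = qf @ kf.transpose(-2, -1)
    weights = scores / scores.sum(dim=-1, keepdim=True)
    return weights @ v


def linear_kernel_attention(q, k, v):
    # Same result, but computes K'^T V first, so nothing of size (n, n) exists.
    qf, kf = feature_map(q), feature_map(k)
    kv = kf.transpose(-2, -1) @ v                    # (..., m, d)
    normalizer = qf @ kf.sum(dim=-2).unsqueeze(-1)   # (..., n, 1)
    return (qf @ kv) / normalizer


torch.manual_seed(0)
# (batch, heads, sequence length, head dim); float64 keeps the comparison tight
q, k, v = (torch.randn(2, 4, 128, 16, dtype=torch.float64) for _ in range(3))
out_quadratic = quadratic_kernel_attention(q, k, v)
out_linear = linear_kernel_attention(q, k, v)
print(torch.allclose(out_quadratic, out_linear))
```

Both functions compute the same output; only the order of the multiplications differs. Note that the linear form does not return the attention matrix p_attn, so code that relies on the second return value of the function above would need to change.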

Thank you!

RaivoKoot commented 3 years ago

Would also like to know.

I will run tests on my current default transformer model, which uses inputs of length ~2000, and report my performance/speed results here.

pzzhang commented 3 years ago

@lucidrains I recently used your implementations of performer (https://github.com/microsoft/vision-longformer/blob/main/src/models/layers/performer.py) and linformer (https://github.com/microsoft/vision-longformer/blob/main/src/models/layers/linformer.py) to compare different efficient attention mechanisms on image classification and object detection tasks. See the results reported here: https://github.com/microsoft/vision-longformer. Thank you for your excellent open-sourced code!

@phypan11 @RaivoKoot You may be interested in the results, too.