lucidrains / performer-pytorch

An implementation of Performer, a linear attention-based transformer, in Pytorch
MIT License

FastAttention doesn't give results in agreement with standard attention? #69

Open simonaxelrod opened 3 years ago

simonaxelrod commented 3 years ago

Hi there,

I ran this code to compare the results of standard attention with fast attention. Surprisingly, I'm getting very large errors (about 80%). Any idea as to where this comes from?

import torch
import numpy as np
from performer_pytorch import FastAttention

num_nodes = 24
feat_dim = 128
nb_features = 8 * feat_dim

num_its = 5
errs = []

for _ in range(num_its):
    Q = torch.randn(1, 1, num_nodes, feat_dim)
    K = torch.randn(1, 1, num_nodes, feat_dim)
    V = torch.randn(1, 1, num_nodes, feat_dim)

    # fast attention

    attn = FastAttention(dim_heads=feat_dim,
                         nb_features=nb_features,
                         causal=False)

    fast = attn(q=Q,
                k=K,
                v=V)

    Q = Q.reshape(-1, feat_dim)
    K = K.reshape(-1, feat_dim)
    V = V.reshape(-1, feat_dim)

    # standard attention

    A = torch.exp(torch.matmul(Q, K.transpose(0, 1)) / feat_dim ** 0.5)
    ones = torch.ones(num_nodes)
    D_inv = torch.diag(1 / torch.matmul(A, ones))
    slow = torch.matmul(D_inv, torch.matmul(A, V))

    err = abs(slow - fast).mean() / abs(slow).mean() * 100

    errs.append(err)

mean_err = np.mean(errs)
std_err = np.std(errs)

print("Error is (%.2f +/- %.2f)%%" % (mean_err, std_err)) # prints Error is (73.28 +/- 1.99)%
simonaxelrod commented 3 years ago

@lucidrains Bumping this up again. Any thoughts on this?

wcshin-git commented 3 years ago

@simonaxelrod Does this kind of large error occur when experimenting with the original code written in Jax from Google?

simonaxelrod commented 3 years ago

I haven't tried it with Jax yet, but I'll give that a shot.

wcshin-git commented 3 years ago

All right, I'm curious too :)

wcshin-git commented 3 years ago

Hi @simonaxelrod, I tried the above comparison using the SLiM Performer code, which was written by the original Performer authors. It is also written in PyTorch, so it was easy to try.

import torch
import numpy as np
from slim_performer_model import MultiHeadAttention

batch = 1
num_nodes = 24 # seq_len
feat_dim = 64
n_heads = 1

num_its = 5
errs = []

for _ in range(num_its):

    # fast attention
    x = torch.randn((batch, num_nodes, feat_dim))  # x: [B, seq_len, feat_dim]
    attn = MultiHeadAttention(feature_type='favor+', n_heads=n_heads, hidden_dim=feat_dim, compute_type='iter')
    rfs = attn.sample_rfs(x.device)  # [n_heads, feat_dim, feat_dim]
    fast = attn.full_forward(x, rfs)  # x: [B, seq_len, feat_dim] -> fast: [B, seq_len, feat_dim]

    # '_get_original_qkv' is a helper I temporarily added to retrieve the Q, K, V used for 'fast' (it is not in the original 'MultiHeadAttention')
    Q, K, V = attn._get_original_qkv(x)  # -> Q, K, V: [B, seq_len, feat_dim]  Note these are the original Q, K, V (not the random-feature-mapped Q' and K').

    # standard attention
    A = torch.einsum('bid, bjd -> bij', Q, K) / feat_dim ** 0.5  # [B, seq_len, seq_len]
    A = torch.nn.Softmax(dim=-1)(A)
    slow = torch.einsum('bij, bjd -> bid', A, V)  # [B, seq_len, feat_dim]

    err = (abs(slow - fast).mean() / abs(slow).mean() * 100).item()

    errs.append(err)

mean_err = np.mean(errs)
std_err = np.std(errs)

print("Error is (%.2f +/- %.2f)%%" % (mean_err, std_err)) # Error is (130.53 +/- 2.27)%

But the error is (130.53 +/- 2.27)%. I don't know why we're getting very large errors... @lucidrains, @simonaxelrod, is this normal?

gaganbahga commented 3 years ago

@simonaxelrod I observed something similar. One more observation: in the code you linked, the error decreases by a lot (to ~1.5%) if, in calculating the standard attention, we scale down the logits, for example by dividing the query matrix by 100. This probably makes the attention distribution much flatter, so I suspect that when Q, K, V are not learned, as in this case, the Performer produces a much flatter attention distribution than regular attention does. I am not sure whether this would still hold as part of a trained neural network, since the weights might then adjust so that it is no longer an issue. The authors also mention that "Backwards compatibility with pretrained models is available as a benefit from softmax approximation, via small finetuning...". So even though the two are compatible, it takes some finetuning to transfer the weights of standard attention to the Performer.
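
A minimal sketch of that check, reusing the FastAttention setup from the first post. Only the query scaling (and a fixed seed, for repeatability) is new; the exact error will vary with the sampled random features, so treat the ~1.5% figure above as indicative rather than exact:

import torch
from performer_pytorch import FastAttention

num_nodes = 24
feat_dim = 128
nb_features = 8 * feat_dim

torch.manual_seed(0)

# same tensors as in the first post, but with the queries scaled down so the logits are small
Q = torch.randn(1, 1, num_nodes, feat_dim) / 100
K = torch.randn(1, 1, num_nodes, feat_dim)
V = torch.randn(1, 1, num_nodes, feat_dim)

# random-feature (FAVOR+) attention
attn = FastAttention(dim_heads=feat_dim, nb_features=nb_features, causal=False)
fast = attn(q=Q, k=K, v=V).reshape(-1, feat_dim)

# exact softmax attention on the same (scaled) inputs
Qf, Kf, Vf = Q.reshape(-1, feat_dim), K.reshape(-1, feat_dim), V.reshape(-1, feat_dim)
A = torch.softmax(Qf @ Kf.t() / feat_dim ** 0.5, dim=-1)
slow = A @ Vf

err = (abs(slow - fast).mean() / abs(slow).mean() * 100).item()
print("Relative error with scaled-down queries: %.2f%%" % err)

With flat logits the exponentials inside the random-feature estimator stay well-behaved, which is presumably why the approximation tightens.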

bngcode commented 2 years ago

Consider equation (7) of the paper. In that lemma, SM+ is defined as an expectation of exponential terms, and the expectation is of course an integral over R^d. What is done in the code and in the paper is to approximate this integral by drawing m = nb_features samples and orthogonalizing them. This introduces a large error: the integral cannot be approximated well enough by taking only m points from R^d.

This could explain the errors reported by @simonaxelrod, or am I missing something here?
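
For reference, the identity being approximated (Lemma 1 of the Performer paper; the exact equation numbering may differ between versions) and its m-sample Monte Carlo estimator can be written as

\mathrm{SM}(x, y) = \exp(x^{\top} y) = \mathbb{E}_{\omega \sim \mathcal{N}(0, I_d)}\Big[\exp\big(\omega^{\top} x - \tfrac{\|x\|^2}{2}\big)\,\exp\big(\omega^{\top} y - \tfrac{\|y\|^2}{2}\big)\Big],

\widehat{\mathrm{SM}}_m(x, y) = \frac{1}{m}\sum_{i=1}^{m} \exp\big(\omega_i^{\top} x - \tfrac{\|x\|^2}{2}\big)\,\exp\big(\omega_i^{\top} y - \tfrac{\|y\|^2}{2}\big), \qquad \omega_i \sim \mathcal{N}(0, I_d) \text{ (orthogonalized in FAVOR+)}.

Being a Monte Carlo average, its error only shrinks like 1/\sqrt{m}, and for i.i.d. features the relative mean squared error is roughly \exp(\|x+y\|^2)\,(1 - \exp(-\|x+y\|^2))/m, which blows up as the logits x^{\top} y grow. That is consistent with the observation above that scaling the queries down (flatter logits) makes the two attentions agree much more closely.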