mlpen / Nystromformer

Possible bug in iterative_inv #3

Closed: sbodenstein closed this issue 3 years ago

sbodenstein commented 3 years ago

When the input of iterative_inv is a 4-tensor (which is how it is used in NystromAttention), the calculation of the initial value of V in iterative_inv here takes the max over all dimensions of the 4-tensor K, rather than over just the matrix dimensions as Lemma 1 requires:

[screenshot attached to the original issue]
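For reference, this is the initialization being discussed (it is also written out later in this thread); note that both norms are per-matrix quantities:

Z_0 = \frac{A_S^\top}{\lVert A_S \rVert_1 \, \lVert A_S \rVert_\infty},
\qquad
\lVert A_S \rVert_1 = \max_j \sum_i \lvert (A_S)_{ij} \rvert \ (\text{max column sum}),
\qquad
\lVert A_S \rVert_\infty = \max_i \sum_j \lvert (A_S)_{ij} \rvert \ (\text{max row sum}).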

This has a significant effect on the accuracy of the approximation on the random data I have tried.

yyxiongzju commented 3 years ago

Hi @sbodenstein, Z_0 = A_S^T / (||A_S||_1 ||A_S||_∞). We did not update the draft. The line of code is doing sum first and then max. Can you read that line again?

sbodenstein commented 3 years ago

@yyxiongzju: OK, maybe I'm not being clear enough, or I'm misunderstanding something. The input to iterative_inv is kernel_2, which has shape (batch, num_heads, num_landmarks, num_landmarks). We want an inverse for each of the (num_landmarks, num_landmarks) matrices (and there are batch × num_heads of them). Each of these matrices has its own value of ||A_S||_1 and ||A_S||_inf. But in your implementation, the max reduces over all remaining dimensions (including the batch and num_heads dimensions), so there is a single value for the whole tensor.

This means that iterative_inv(data[0,0]) is different from iterative_inv(data)[0,0] (you can verify this with torch.manual_seed(1); data = torch.randn(1, 2, 5, 5)). Is that intended?
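A quick, self-contained way to see the difference (a sketch, not the repo code; only the coefficient computation is shown):

import torch

torch.manual_seed(1)
data = torch.randn(1, 2, 5, 5)   # (batch, num_heads, num_landmarks, num_landmarks)
K = torch.abs(data)

# Current behaviour: sum over dim -2, then max with no dim argument reduces over
# everything that is left, giving one scalar shared by every matrix in the batch.
shared = torch.max(torch.sum(K, dim=-2))

# Per-matrix version: keep the batch and head dimensions, take the max only over
# the landmark dimension, giving one value per (batch, head) matrix.
per_matrix = torch.max(torch.sum(K, dim=-2, keepdim=True), dim=-1, keepdim=True).values

print(shared.shape)        # torch.Size([])        - a single value
print(per_matrix.shape)    # torch.Size([1, 2, 1, 1]) - one value per matrix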

Basically, shouldn't it be this instead:

def iterative_inv(mat, n_iter = 6):
    I = torch.eye(mat.size(-1), device = mat.device)
    K = mat
    # ||K||_1 per matrix: max absolute column sum, shape (..., 1, 1)
    a1 = torch.max(torch.sum(torch.abs(K), dim = -2, keepdim = True), dim = -1, keepdim = True).values
    # ||K||_inf per matrix: max absolute row sum, shape (..., 1, 1)
    a2 = torch.max(torch.sum(torch.abs(K), dim = -1, keepdim = True), dim = -2, keepdim = True).values
    V = 1 / (a1 * a2) * K.transpose(-1, -2)
    for _ in range(n_iter):
        KV = torch.matmul(K, V)
        V = torch.matmul(0.25 * V, 13 * I - torch.matmul(KV, 15 * I - torch.matmul(KV, 7 * I - KV)))
    return V
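With per-matrix coefficients, the batched and unbatched results should agree (continuing the small example above, with data from the earlier snippet in this comment; this check is illustrative only):

full = iterative_inv(data)
single = iterative_inv(data[0, 0])
# Should print True, up to floating-point differences between batched and unbatched
# matmuls; with a single global max it generally does not.
print(torch.allclose(full[0, 0], single))
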
yyxiongzju commented 3 years ago

@sbodenstein, thanks for pointing this out. We did not realize the difference between the current coefficient computation and the math formula in the paper. As I replied to you by email, you are right about the coefficient of the initialization point for the pseudoinverse approximation.

Since ||A_S||_inf = 1 because of the softmax, the difference comes down to a1 = torch.max(torch.sum(torch.abs(K), dim = -2)) versus a1 = torch.max(torch.sum(torch.abs(K), dim = -2, keepdim=True), dim=-1, keepdim=True).values. With the max reducing over all remaining dimensions, there is a scale difference compared to taking the max only over the (num_landmarks, num_landmarks) axes, so the current coefficient computation is more conservative in that sense.

Even though the current initialization coefficient does not affect the final pseudoinverse convergence or the performance (6 iterations are similar to 8 iterations on BookCorpus), following the math formulation exactly may need fewer iterations to reach a reasonable pseudoinverse approximation. It also removes the dependence of the initialization on the batch size that you pointed out.
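To make the "||A_S||_inf = 1" point concrete, here is a small illustration of my own (not the repo code): when K comes from a row-wise softmax, every row sums to 1, so Lemma 1's initialization reduces to K^T / ||K||_1 with ||K||_1 computed per matrix.

import torch

torch.manual_seed(0)
K = torch.softmax(torch.randn(1, 2, 5, 5), dim=-1)   # attention-style kernel: rows sum to 1

# ||K||_inf is the max absolute row sum; after a row-wise softmax every row sum is 1.
print(torch.sum(K, dim=-1))          # all ones, so ||K||_inf = 1

# With ||K||_inf = 1, the initialization reduces to K^T / ||K||_1, where ||K||_1
# (max column sum) is computed per (num_landmarks, num_landmarks) matrix.
norm_1 = torch.max(torch.sum(K, dim=-2, keepdim=True), dim=-1, keepdim=True).values
Z0 = K.transpose(-1, -2) / norm_1    # broadcasts one coefficient per matrix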

sbodenstein commented 3 years ago

Thanks! Can I make a PR to fix this, or do you want to keep it as it is?

mlpen commented 3 years ago

We updated the code and added an option for how the initialization coefficient is computed. Thanks again for pointing out the difference.
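For readers following along, roughly what such an option could look like (a sketch only; the helper and flag names here are hypothetical, see the repository for the actual implementation):

import torch

def init_coefficient(K, init_option="exact"):
    # Hypothetical helper illustrating the two choices discussed in this thread.
    if init_option == "original":
        # Conservative: one scalar for the whole batch (sum over columns, then a global max).
        return torch.max(torch.sum(torch.abs(K), dim=-2))
    # Exact (Lemma 1, using ||K||_inf = 1 from the softmax): per-matrix max column sum.
    return torch.max(torch.sum(torch.abs(K), dim=-2, keepdim=True), dim=-1, keepdim=True).values

# V0 = K.transpose(-1, -2) / init_coefficient(K)
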

sbodenstein commented 3 years ago

Thanks!