mlpen / Nystromformer

Incorrect initialization of pseudoinverse matrix calculation leads to convergence failure #6

Open kpot opened 3 years ago

kpot commented 3 years ago

Hi!

First, let me thank you all for the amazing paper! It's detailed enough that I've managed to successfully reproduce Nyströmformer in Tensorflow, and so far I'm impressed by how well it works in my humble tests.

Later I found this repository, and while digging through the code I was surprised to learn that the initialization you use within the iterative_inv method is quite different from the one suggested in "A New Iterative Method for Finding Approximate Inverses of Complex Matrices" by Razavi et al. and used in my implementation. Thinking "Oh, maybe my implementation could be further improved", I copied and tested your code, only to discover that it actually fails to converge for any significantly large matrix.

I've extracted iterative_inv into a simple test that runs the same set of increasingly large matrices through the method with varying initialization schemes until it converges, and then reports the results. In addition to the original two initialization methods, I've added the one recommended by Razavi et al. (called razavi in the report). This is what I got after running the test:

### Testing method 'original' ###
  Round 1 matrix size 2x2...SOLVED in 6 steps
  Round 2 matrix size 4x4...SOLVED in 9 steps
  Round 3 matrix size 8x8...ABORTED after 1000 iterations
  Round 4 matrix size 16x16...ABORTED after 4 iterations: Some of the values aren't finite anymore
  Round 5 matrix size 32x32...ABORTED after 3 iterations: Some of the values aren't finite anymore
  Round 6 matrix size 64x64...ABORTED after 3 iterations: Some of the values aren't finite anymore
  Round 7 matrix size 128x128...ABORTED after 3 iterations: Some of the values aren't finite anymore
### Testing method 'razavi' ###
  Round 1 matrix size 2x2...SOLVED in 6 steps
  Round 2 matrix size 4x4...SOLVED in 10 steps
  Round 3 matrix size 8x8...SOLVED in 10 steps
  Round 4 matrix size 16x16...SOLVED in 12 steps
  Round 5 matrix size 32x32...SOLVED in 16 steps
  Round 6 matrix size 64x64...SOLVED in 15 steps
  Round 7 matrix size 128x128...SOLVED in 17 steps
### Testing method 'other' ###
  Round 1 matrix size 2x2...SOLVED in 6 steps
  Round 2 matrix size 4x4...SOLVED in 9 steps
  Round 3 matrix size 8x8...ABORTED after 1000 iterations
  Round 4 matrix size 16x16...ABORTED after 4 iterations: Some of the values aren't finite anymore
  Round 5 matrix size 32x32...ABORTED after 3 iterations: Some of the values aren't finite anymore
  Round 6 matrix size 64x64...ABORTED after 3 iterations: Some of the values aren't finite anymore
  Round 7 matrix size 128x128...ABORTED after 3 iterations: Some of the values aren't finite anymore

Here is the test itself:

import math
import torch

# how many iterations we can spend on finding the inverse matrix
MAX_ITERATIONS = 1000
# maximum allowed value for matmul(A, A_inv) - I
MAX_ERROR = 1e-5

class NotFinite(Exception):
    def __init__(self, text, iteration):
        self.iteration = iteration
        super().__init__(text)

# Extracted from https://github.com/mlpen/Nystromformer/blob/main/code/attention_nystrom.py
# and adjusted to run until either the error drops below MAX_ERROR
# or MAX_ITERATIONS is reached
def iterative_inv(mat, max_error, init_option):
    I = torch.eye(mat.size(-1), device=mat.device)
    K = mat

    # The entries of K are positive and ||K||_{\infty} = 1 due to softmax
    if init_option == "original":
        # This original implementation is more conservative to compute
        # coefficient of Z_0.
        V = 1 / torch.max(torch.sum(K, dim=-2)) * K.transpose(-1, -2)
    elif init_option == 'razavi':  # added by me
        # This initialization is proposed in the original article
        # "A New Iterative Method for Finding Approximate Inverses of
        # Complex Matrices" by Razavi et al (after the proof of Theorem 3).
        # https://www.hindawi.com/journals/aaa/2014/563787/
        V = (K.transpose(-2, -1)
             /
             (torch.norm(K, p=1, dim=(-2, -1), keepdim=True)
              * torch.norm(K, p=math.inf, dim=(-2, -1), keepdim=True)))
    else:
        # This is the exact coefficient computation, 1 / ||K||_1,
        # of initialization of Z_0, leading to faster convergence.
        V = (1 / torch.max(torch.sum(K, dim=-2), dim=-1).values[:, :, None, None]
             * K.transpose(-1, -2))

    KV = torch.matmul(K, V)
    i = 0
    while torch.max(torch.abs(KV - I)) > max_error and i < MAX_ITERATIONS:
        V = torch.matmul(0.25 * V, 13 * I - torch.matmul(KV,
            15 * I - torch.matmul(KV, 7 * I - KV)))
        KV = torch.matmul(K, V)
        i += 1
        if not torch.all(torch.isfinite(V)):
            raise NotFinite("Some of the values aren't finite anymore", i)
    return V, i

test_matrices = []
for i in range(1, 8):
    size = 2**i
    # Imitating a matrix of shape (batch_size, num_heads, num_q_landmarks, num_k_landmarks)
    M = torch.rand((1, 1, size, size))
    test_matrices.append((size, M))

for init_method in ('original', 'razavi', 'other'):
    print('### Testing method', repr(init_method), '###')
    for i, (size, M) in enumerate(test_matrices):
        print('  Round {round} matrix size {size}x{size}...'
              .format(round=i + 1, size=size),
              end='')
        try:
            M_inv, last_iter = iterative_inv(M, max_error=MAX_ERROR, init_option=init_method)
        except NotFinite as e:
            print("ABORTED after", e.iteration, "iterations:", str(e))
        else:
            if last_iter == MAX_ITERATIONS:
                print('ABORTED after', last_iter, 'iterations', flush=True)
            else:
                print('SOLVED in', last_iter, 'steps', flush=True)

I hope it will help to improve Nyströmformer.

P.S. Please correct formula (14) in the paper; it's got the parentheses wrong :)

yyxiongzju commented 3 years ago

@kpot, Thanks for your interest, and sorry for the late reply; I had several deadlines.

We did use the initialization method that Razavi et al. suggested, Z_0 = A_S^T / (||A_S||_1 ||A_S||_{\infty}). Since the input A_S is a softmax matrix in our case, A_S = softmax(\tilde{Q}\tilde{K}^T), we have ||A_S||_{\infty} = 1, which is why we provided the implementation the way we did. The matrices you are testing with are general random matrices; that makes the difference and leads to the "convergence failure".
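
To illustrate (a quick standalone check, not code from this repo): softmax over the last dimension makes every row of A_S sum to 1, so its infinity-norm (the maximum absolute row sum) is exactly 1, and Z_0 reduces to A_S^T / ||A_S||_1:

import torch

# A_S stands in for softmax(Q~ K~^T); the shape here is an arbitrary example.
A = torch.softmax(torch.randn(2, 4, 32, 32), dim=-1)

inf_norm = torch.max(torch.sum(A, dim=-1))                  # max row sum, equals 1 up to rounding
one_norms = torch.max(torch.sum(A, dim=-2), dim=-1).values  # per-matrix max column sum

print(inf_norm)         # a value of 1 (up to float rounding)
print(one_norms.shape)  # torch.Size([2, 4])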

Formula (14) in the Nyströmformer paper is correct. I wonder if you are referring to an old version?

kpot commented 3 years ago

@yyxiongzju Good point, I completely missed the fact (and the comment!) that iterative_inv works exclusively with post-softmax matrices. In that case your initialization is perfectly correct. Thanks for the explanation! Although I would rename the argument to softmax_mat to emphasize that iterative_inv only works for this particular class of matrices.

However, can you explain why you need the "original" initialization at all? Look at its expression:

V = 1 / torch.max(torch.sum(K, dim = -2)) * K.transpose(-1, -2)

Assuming K is a batch of different matrices, one would expect to compute an independent 1-norm for each of them. The "original" expression torch.max(torch.sum(K, dim = -2)), however, collapses all the potential 1-norms into a single scalar shared by every matrix. The "other" initialization does it right: torch.max(torch.sum(K, dim=-2), dim=-1) yields an individual 1-norm for each matrix in the batch, so I'm not surprised it works better. It seems the "original" should be removed completely, since it is incorrect anyway. Why do you keep it?
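
To make the difference concrete, here is a standalone shape check (not code from the repo; the sizes are arbitrary):

import torch

# (batch, heads, landmarks, landmarks), the same layout iterative_inv expects
K = torch.softmax(torch.randn(8, 12, 64, 64), dim=-1)

scalar_norm = torch.max(torch.sum(K, dim=-2))                      # "original": one value for the whole batch
per_matrix_norms = torch.max(torch.sum(K, dim=-2), dim=-1).values  # "other": one value per matrix

print(scalar_norm.shape)       # torch.Size([])
print(per_matrix_norms.shape)  # torch.Size([8, 12])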

In version 2 of the paper (updated on arXiv on Mar 5th, 2021) I see this formula (notation simplified): Z = 0.25 Z(13I - AZ(15I - AZ)(7I - AZ)). The original paper by Razavi et al., as well as iterative_inv, uses a different expression: Z = 0.25 Z(13I - AZ(15I - AZ(7I - AZ))). You need to move the closing parenthesis that follows 15I - AZ to the end of the formula.

kpot commented 3 years ago

@yyxiongzju On top of the previous question: in the paper you say "For all our experiments, we need to run about 6 iterations in order to achieve a good approximation of the pseudoinverse". What do you mean by a good approximation? My best guess is that you calculated some convergence measure (like ||I - AZ||_{F}) and found the moment its rate of descent drops below a certain threshold. Could you elaborate on that?

I just don't feel comfortable setting n_iter = 6 and calling it a day. For some matrices the algorithm converges so fast that 1-2 iterations are enough. For others the solution may begin to diverge after 4 iterations, and performing 6 of them actually worsens the outcome. On top of that, the speed of convergence is inversely proportional to the size of the matrix, which in turn depends on the number of landmarks, which depends on the task at hand. I believe it's better to have a more general heuristic for what counts as a "good enough approximation", something akin to the sqrt(d_k) normalization in QKV attention.
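
Something like this rough sketch of mine is what I have in mind (not code from the repo; the tol and max_iter values are arbitrary placeholders, not tuned):

import torch

def iterative_inv_adaptive(softmax_mat, tol=1e-4, max_iter=20):
    # Same iteration as iterative_inv, but stopping on the ||I - KV||_F error
    # instead of running a fixed number of steps.
    K = softmax_mat
    I = torch.eye(K.size(-1), device=K.device)
    V = (1 / torch.max(torch.sum(K, dim=-2), dim=-1).values[:, :, None, None]
         * K.transpose(-1, -2))
    for _ in range(max_iter):
        KV = torch.matmul(K, V)
        if torch.max(torch.norm(I - KV, p='fro', dim=(-2, -1))) < tol:
            break
        V = torch.matmul(0.25 * V, 13 * I - torch.matmul(KV,
            15 * I - torch.matmul(KV, 7 * I - KV)))
    return V

(The data-dependent break would need special handling in a traced/compiled graph, of course, but it shows the idea.)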

yyxiongzju commented 3 years ago

@kpot, this difference was pointed out by sbodenstein in a closed issue. We keep the original one just for reproducibility. I will update the code to make the "other" initialization the default.

The parenthesis issue was also pointed out by thomasw21 in a closed issue. I did not notice that the arXiv version had not been updated either. The corrected version will be published by AAAI soon; you can also look at that one.

yyxiongzju commented 3 years ago

@kpot, Good point. We actually used the ||I - AZ||_{F} error to see how many iterations we need for a good approximation and we also compared with numpy.linalg.pinv.

You are right that the number of iterations depends on the number of landmarks for the task. Setting a reasonable threshold w.r.t. the ||I - AZ||_{F} error will be better than a fixed number of iterations.

kpot commented 3 years ago

@yyxiongzju Thanks for the details! I've thrown together a notebook containing an alternative version of iterative_inv along with some of my experiments. Could you please take a look?

So far I see that 6 iterations are clearly not enough for the currently published algorithm to converge to anything even close to the required pseudoinverse; it's just not fast enough in terms of convergence speed. In fact, "in vivo" (on a real BERT model trained from scratch with the 6-step Nystromformer) I see that 6 steps are often not enough for the solutions to even begin to converge. Often after 6 steps they have just barely gotten started (having covered 1-5% of the distance), picking up speed after 8-10 steps and arriving at the result after 16-20 iterations. Yet the whole model obviously still works.

So far I can conclude that:

  1. 6 steps are absolutely not enough for the current algorithm to converge to a good approximation, but they are also not enough for it to start diverging significantly, which saves the model from the sudden appearance of non-finite values. And because deep learning is famously resilient to noise, imprecise gradients, etc., the model can still learn something. I believe a better approximation could further improve the model's performance.
  2. The same paper by Razavi et al. contains a novel method (part 3) that is more computationally expensive but converges about twice as fast in my experience. It can indeed arrive at a decent approximation in 6 steps. Also, I quote:

    "Because the computation of a matrix norm (usually for dense matrices and for large sparse matrices) takes a reasonable time, therefore higher number of steps/iterations (which is the result of lower order methods) will be costlier than the lower number of steps/iterations".

yyxiongzju commented 3 years ago

@kpot, Thanks for sharing your notebook with me. To make the point clearer: we did NOT notice the difference in our PyTorch implementation. That is why we only fixed the initialization issue after sbodenstein pointed it out. The original one is kept just for reproducibility.

1. Why do we say 6 steps are enough? When we ran the experiments to show how many iterations are needed to achieve a good approximation, we implemented them in NumPy, using numpy.linalg.norm for each matrix, which avoids the mistake we made in our PyTorch initialization implementation. You can see the approximation ablation studies in our supplementary file. We also include one approximation result here to show that 6 steps can achieve a good approximation.

[attached image: IMG_1259, approximation result]

2. Why not use the more computationally expensive method? We focus on computational efficiency, which means we consider not only the number of steps but also how many matrix multiplications each step requires. The more expensive one needs more matrix multiplications in total for the same 6 steps.
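
As a rough illustration (my own count from the iterative_inv code extracted above, so treat it as an estimate): each step does

  KV = K V                                    -> 1 matrix multiplication
  V  = 0.25 V (13I - KV(15I - KV(7I - KV)))   -> 3 more matrix multiplications

so 6 steps come to roughly 24 multiplications of the small landmark-sized matrices, and a method that needs more multiplications per step can easily cost more overall even if it converges in fewer steps.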

@kpot, Those conclusions are based on experiments we did. In the acknowledgement section of the paper we also thank sbodenstein for pointing out the initialization issue, along with everyone else whose suggestions and comments made the paper better.

Your notebook looks good. You could open a pull request to add it, and I will thank you on the main page of this repo.

kpot commented 3 years ago

@yyxiongzju Thank you for your time, you did a good job explaining all this! I totally agree with you that practicality is what matters in such things. If a less expensive algorithm produces the same result (especially the same attention map), there is no point in arguing with the facts. This is really great.

You made me curious, so I decided to take a pre-trained Transformer and draw some attention colormaps of my own. One of my models is a Nystromformer-based ALBERT pre-trained for several days on WikiText-103 to a perplexity of about 37. I must admit I deviated a bit in its implementation: first, I used locality-sensitive hashing to construct the landmarks (like in the Reformer model; I can share my implementation if you like). Second, it used the early version of the dynamic iterative_inv, which mostly ran about 9-16 iterations per call until it fully converged to the same result as tf.linalg.pinv (my code is in Tensorflow).

This is what I got when plugging the fixed-step iterative_inv into this model and comparing it against tf.linalg.pinv on a couple of examples from the same dataset:

[attached image: attention_matrices_iterative_inv_vs_pinv2]

[attached image: attention_matrices_iterative_inv_vs_pinv3]

To be honest, I'm not sure this proves anything or can be taken seriously. Perhaps if the model had originally been trained with the "6-step algorithm", things would look better for it. But in this case some attention heads are clearly messed up.

yyxiongzju commented 3 years ago

@kpot, Thanks for sharing these interesting results with us.

We also compared the early initialization version with the updated initialization scheme. We extracted features from a pre-trained standard Transformer. The early initialization takes ~12 steps to achieve a result similar to what the updated scheme achieves in 6 steps (close to linalg.pinv). That looks consistent with what you are seeing for your pre-trained Nystromformer-based model.

Since you trained your Nystromformer-based model with the dynamic iterative_inv, the model learned self-attention using the dynamic iterative_inv (close to linalg.pinv). It makes sense that 6 steps leave some attention heads messy when applied to a model trained with a different scheme. Also, you can see the attention heads are quite similar to the linalg.pinv ones if you use 12 steps. I think it might look better if you trained the model with the 6-step algorithm, and the 6-step updated initialization scheme should look more consistent with linalg.pinv.

kpot commented 3 years ago

@yyxiongzju Good point! Indeed, if you train the network exclusively with the 6-step iterative_inv, since the whole thing is differentiable, we can expect the network to adjust itself in a way that essentially pre-conditions the input matrix to facilitate its inversion within the given 6 steps.

However, the network can only do so much with the matrix, because it is constrained by the softmax operation. For example, it cannot set any element of the matrix to 1.2 or -0.1, or make the sum of a row anything other than 1. Besides that, the network is also limited by the convergence order of the algorithm: even in the best possible case this particular algorithm has 3rd-order convergence. BTW, this makes me wonder how this 18th-order algorithm would work in this scenario. Also, what if we could make the initialization somehow learnable...

Anyway, I've re-trained the whole model from scratch (not just copied some pre-trained weights, so I can't guarantee the same heads and layers learned the same tasks). I kept the model as canonical as I could: the same 64 average-pool landmarks, the same strictly fixed 6-step algorithm, everything. I kept training until it reached the same performance metrics. Then I plotted the same attention graphs, only this time throwing in "The Real" QKV attention for comparison:

[attached image: attention_matrices_mlm_avgpool_landmarks]

I'm seriously impressed by how well the Nystromformer manages to approximate full attention! Let me thank you again for the work you've done! But I still see how the fixed 6-step algorithm messes up some heads, despite its being the only algorithm the network has ever seen. Arguably about 50% of the heads could definitely benefit from at least 8 steps. Those strong vertical stripes, if I read them correctly, amount to "all queries attending to the same few keys".

yyxiongzju commented 3 years ago

@kpot, I think it is an interesting question to explore whether Nystromformer works beyond the softmax matrix, e.g. f(QK^T) \approx f(Q\tilde{K}^T) x f(\tilde{Q}\tilde{K}^T)^+ x f(\tilde{Q}K^T). In that case you cannot assume the input to the pseudoinverse computation is a softmax matrix in the implementation, but it is easy to change it to fit the pseudoinverse computation, as you did.

If you do not care about computational efficiency, you can try the 18th-order algorithm. We actually tried those algorithms for softmax-matrix pseudoinverses with random matrices, and they seem to work. We did not try it in the whole model, but I think it can work.

I do not understand what you mean by making the initialization learnable. Do you mean the initialization could be adapted to different tasks, rather than being fixed? It would be interesting if you could design an adaptive initialization scheme. Or do you mean making the landmarks learnable? For landmarks, you can add a parameter matrix (a linear layer) to learn them. It would be interesting to see whether learnable landmarks work out.
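
Roughly what I mean by a parameter matrix for the landmarks (just an illustrative sketch, not code from this repo; the class name and shapes are made up):

import torch
import torch.nn as nn

class LearnableLandmarks(nn.Module):
    # Replaces average-pooling landmarks with a learnable linear mixing
    # over sequence positions. Assumes a fixed sequence length.
    def __init__(self, seq_len, num_landmarks):
        super().__init__()
        self.mix = nn.Linear(seq_len, num_landmarks, bias=False)

    def forward(self, x):
        # x: (batch, heads, seq_len, head_dim)
        # returns: (batch, heads, num_landmarks, head_dim)
        return self.mix(x.transpose(-1, -2)).transpose(-1, -2)

Whether the extra parameters pay off compared to simple average pooling is an open question, of course.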

Thanks for sharing these results. It looks like 8 steps align better than 6. Did you see any performance improvement from a fixed 8-step scheme? I tried applying 6/8 steps to a vision task, T2T-ViT. The performance is very close, around 78% top-1 accuracy on ImageNet with direct deployment (I did not train it at all).