fakerybakery closed this issue 5 months ago.
Hi there, I'm not a maintainer of GaLore, but the problem here lies in increasing the sequence length. Taken from this paper.
One of the big downsides of transformer models is that as you increase the sequence length, the memory cost of attention grows quadratically. The attention mechanism compares every token in the sequence with every other token, producing one number (a score) per comparison, and all of these scores have to be stored in a matrix. So if you double the sequence length, you need four times the memory, because the matrix now holds four times as many values; if you quadruple it, you need 16 times the memory, and so on.
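To put rough numbers on the quadratic growth, here is a minimal sketch that counts the bytes needed to keep one full score matrix per head, per layer. The dimensions are assumptions chosen to resemble a LLaMA-7B-style model (32 layers, 32 heads, fp16 scores, batch size 1), and the result is only a lower bound, since real implementations keep other tensors too.

```python
GIB = 1024 ** 3


def score_matrix_bytes(seq_len: int, n_heads: int = 32, n_layers: int = 32,
                       bytes_per_value: int = 2, batch_size: int = 1) -> int:
    """Bytes to hold one seq_len x seq_len score matrix per head, per layer.

    Assumed defaults: 32 layers, 32 heads, 2-byte (fp16) scores, batch size 1.
    """
    return batch_size * n_layers * n_heads * seq_len * seq_len * bytes_per_value


for n in (512, 1024, 2048, 8192):
    print(f"{n:>5} tokens: {score_matrix_bytes(n) / GIB:8.2f} GiB")
```

With these assumed dimensions, the loop reports 0.50 GiB at 512 tokens, 2.00 GiB at 1024 tokens, and 128.00 GiB at 8192 tokens: doubling the length quadruples the memory.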
Thankfully, recent techniques such as flash attention have brought this down to linear scaling: they compute attention block by block and never materialize the full score matrix, so if you double the sequence length, you "only" need double the memory. But even with flash attention, going from a sequence length of 512 to 8192 (16 times longer) still means you need sixteen times as much memory for the attention-related activations as before.
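The sketch below contrasts the two regimes with a deliberately simplified model of what attention keeps around for backpropagation. It assumes four (seq_len, hidden) tensors per layer (Q, K, V and the output), that the block-wise (flash) variant keeps only one fp32 softmax statistic per row per head, and that standard attention keeps the full score matrix per head; hidden size 4096, 32 heads, 32 layers and fp16 are illustrative assumptions, not GaLore settings.

```python
def attention_activation_bytes(seq_len: int, hidden: int = 4096, n_heads: int = 32,
                               n_layers: int = 32, bytes_per_value: int = 2,
                               flash: bool = True) -> int:
    """Simplified memory for attention tensors kept for backprop (batch size 1)."""
    # Q, K, V and the attention output: four (seq_len, hidden) tensors per layer.
    qkvo = 4 * seq_len * hidden * bytes_per_value
    if flash:
        # Block-wise attention: one fp32 (4-byte) softmax statistic per row per head.
        extra = n_heads * seq_len * 4
    else:
        # Standard attention: the full (seq_len, seq_len) score matrix per head.
        extra = n_heads * seq_len * seq_len * bytes_per_value
    return n_layers * (qkvo + extra)


for flash in (True, False):
    ratio = (attention_activation_bytes(8192, flash=flash)
             / attention_activation_bytes(512, flash=flash))
    print(f"flash={flash}: 512 -> 8192 tokens needs {ratio:.0f}x the memory")
```

This prints 16x for the flash case and 136x without it. The non-flash ratio falls short of the pure 256x because, at short lengths, the linear Q/K/V/output term is still a large share of the total.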
Compared to that, shearing off two billion parameters (going from 7B to 5B) unfortunately won't be enough to fit a model at that sequence length into your 24 GB GPU. GaLore can only do so much here: it reduces the memory taken by the optimizer states by projecting the gradients into a low-rank subspace, but it does not reduce the activations that have to be kept around while the model's parameters are adjusted (done through a method called backpropagation), and those activations grow with sequence length. This is largely a consequence of the self-attention mechanism found in transformers, and it's a tough challenge to deal with.
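A rough way to see why GaLore doesn't help with long sequences: in the sketch below, the rank only appears in the optimizer-state formula, and sequence length only appears in the activation formula. Both functions are illustrative assumptions, not GaLore's code. The optimizer part assumes Adam's two fp32 moment buffers kept for a projected gradient of shape (rank, n), in the spirit of GaLore; the activation part is a crude lower bound of one fp16 (seq_len, hidden) tensor saved per layer, with a hypothetical 4096 hidden size, 32 layers and rank 1024.

```python
from typing import Optional

MIB = 1024 ** 2


def adam_state_bytes(m: int, n: int, rank: Optional[int] = None,
                     bytes_per_value: int = 4) -> int:
    """Adam's two moment buffers for one m x n weight matrix.

    rank=None models plain Adam; with a rank, the moments are kept for a
    GaLore-style projected gradient of shape (rank, n) instead of (m, n).
    """
    rows = m if rank is None else rank
    return 2 * rows * n * bytes_per_value


def activation_lower_bound_bytes(seq_len: int, hidden: int = 4096, n_layers: int = 32,
                                 batch_size: int = 1, bytes_per_value: int = 2) -> int:
    """Crude lower bound: one (batch, seq_len, hidden) tensor saved per layer."""
    return n_layers * batch_size * seq_len * hidden * bytes_per_value


print("Adam state, 4096x4096 matrix:",
      adam_state_bytes(4096, 4096) // MIB, "MiB full vs",
      adam_state_bytes(4096, 4096, rank=1024) // MIB, "MiB at rank 1024")
for seq_len in (512, 8192):
    print(f"activations at {seq_len} tokens: >= {activation_lower_bound_bytes(seq_len) // MIB} MiB")
```

Under these assumptions, the projection cuts one matrix's Adam state from 128 MiB to 32 MiB, while the activation lower bound still climbs from 128 MiB to 2048 MiB when going from 512 to 8192 tokens, whatever rank is chosen.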
GaLore's example for pre-training LLaMA-7B on a 24 GB GPU uses a maximum sequence length of 2048 by default. For longer sequences at roughly the same parameter count, you would need multiple GPUs with more VRAM, such as A40s/L40s or A100s/H100s.
Hope this helps!
Thank you so much for explaining! Makes sense now.
Hi, thanks for releasing GaLore! I'm running out of memory whenever I use a sequence length longer than 512, even with a smaller model. I can train a 7B model with a sequence length of 512 on 24 GB of VRAM, but I can't train a 5B model with a sequence length of 8192. Thanks!