jlamprou / Infini-Attention

Efficient Infinite Context Transformers with Infini-attention PyTorch Implementation + QwenMoE Implementation + Training Script + 1M context passkey retrieval
https://arxiv.org/abs/2404.07143

Have you tried to segment the hidden states within the attention class? #7

Closed ZackZikaiXiao closed 4 months ago

ZackZikaiXiao commented 4 months ago

Have you tried to segment the hidden states within the attention class? I tried, but it's very unstable; the value memory gate easily reaches NaN.

jlamprou commented 4 months ago

@ZackZikaiXiao If you read the README carefully, I explain why I don't do the segmentation in-class. The paper follows the same segmentation logic as Transformer-XL, Memformers, etc. The complexity of in-class segmentation is O(S^2), which doesn't match the O(S) complexity described in the paper. A rough sketch of the training-loop segmentation is below.
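
For illustration, a minimal sketch of the Transformer-XL-style loop I mean. The `ToyMemoryModel` is a made-up stand-in (not this repo's model class); only the loop structure matters: the sequence is split in the training loop and a fixed-size memory state is carried across segments.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy stand-in for a model whose layers carry a compressive memory state.
# It only exists to make the loop below runnable; it is NOT this repo's model.
class ToyMemoryModel(nn.Module):
    def __init__(self, vocab_size=1000, dim=64):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, ids, labels, memory=None):
        h = self.emb(ids)
        if memory is not None:
            h = h + memory  # toy use of the carried state
        logits = self.head(h)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
        # Return a fixed-size state, detached so gradients don't flow across segments.
        new_memory = h.mean(dim=1, keepdim=True).detach()
        return loss, new_memory

model = ToyMemoryModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

segment_len = 128
input_ids = torch.randint(0, 1000, (1, 8 * segment_len))  # one long sequence

memory = None
for start in range(0, input_ids.size(1), segment_len):
    seg = input_ids[:, start:start + segment_len]
    # Each forward/backward only ever touches `segment_len` tokens, so
    # activation memory stays bounded regardless of total sequence length.
    loss, memory = model(seg, labels=seg, memory=memory)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```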

ZackZikaiXiao commented 4 months ago

Yeah, the author segments the input in the training loop, and I have read the Transformer-XL code. I'm still curious about the potential performance of segmenting the hidden states, since the code is more concise (it only replaces the attention class). Do you think segmenting inside the attention would still give acceptable perplexity, setting complexity aside?

jlamprou commented 4 months ago

@ZackZikaiXiao I will reiterate: read the README carefully, I have tested both ways of doing it. The whole point of the paper is having fixed, bounded memory while using a huge global context. Doing the segmentation in-class consumes about the same memory as global SDPA attention, so you can't reach huge context lengths like 1M tokens because it's impossible to fit in any VRAM. There is really no point in segmenting in-class: memory-wise it's the same as global SDPA attention, and performance-wise it's worse. You can't match the performance of global softmax (which is non-linear) with linear kernels, at least not with the existing research.
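
To make the memory argument concrete, a rough sketch (the shapes and names are illustrative, not this repo's code): even if you loop over segments inside the attention module, the full-length hidden states and the concatenated output are still materialized, so the peak activation footprint is comparable to plain SDPA.

```python
import torch

batch, seq_len, dim, segment_len = 1, 8192, 1024, 2048

# The full-length hidden states already live in memory before attention runs...
hidden_states = torch.randn(batch, seq_len, dim)

memory = torch.zeros(dim, dim)  # fixed-size compressive memory: doesn't help,
                                # because the full-length tensors above dominate
outputs = []
for start in range(0, seq_len, segment_len):
    seg = hidden_states[:, start:start + segment_len]
    # ... per-segment attention using `memory` would go here ...
    outputs.append(seg)  # placeholder for the segment's output

# ...and the concatenated output is again (batch, seq_len, dim), so peak
# activation memory is roughly the same as running global SDPA attention.
out = torch.cat(outputs, dim=1)
print(out.shape)  # torch.Size([1, 8192, 1024])
```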

ZackZikaiXiao commented 4 months ago

Yes, the "fixed, bounded memory", the "O(S) complexity", and the Transformer-XL code provide clear evidence for segmenting in the training loop in the paper. I'm not very clear on what global softmax, non-linearity, and linear kernels mean, or how they relate to segmentation. If it is possible to perform internal segmentation, combined with PEFT and some memory structures (like ring or tree) or the DNC and NTMs mentioned in the README, what do you think about the feasibility of that approach?

jlamprou commented 4 months ago

@ZackZikaiXiao When I say "You can't match the performance of global softmax (which is non-linear) with linear kernels, at least not with the existing research", I'm trying to explain that if you do the segmentation in-class you have zero advantages: you don't save memory, and you lose performance, since normal SDPA attention is just better. It's possible to do it in-class, and there are other implementations of the paper on GitHub that do it that way, but there is no point to it. This is not a better attention scheme; it's a series of math tricks to avoid keeping the whole global attention in memory. We already have a memory structure in-class, which is essentially what the paper calls "compressive memory". Let's leave the DNC and NTM suggestions out of this, because this conversation is going to get even more complicated. 🙃
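
To unpack the softmax vs. linear-kernel point: roughly, the paper's compressive memory retrieves values through a kernel feature map σ (ELU + 1 in the paper) instead of a global softmax, which is what makes a fixed-size memory matrix possible, but also why it can only approximate softmax attention over the full history. A simplified sketch (single head, no delta rule; the `eps` is my own guard, and that division is one place NaNs can show up if the normalizer gets near zero):

```python
import torch
import torch.nn.functional as F

def sigma(x):
    # Kernel feature map from the paper (ELU + 1); keeps entries positive.
    return F.elu(x) + 1.0

d_k, d_v, seg = 64, 64, 256
M = torch.zeros(d_k, d_v)   # compressive memory: fixed size regardless of context length
z = torch.zeros(d_k, 1)     # its normalization term

def update_memory(M, z, K, V):
    # Linear-attention style update: the memory grows in content, not in size.
    M = M + sigma(K).transpose(-1, -2) @ V
    z = z + sigma(K).sum(dim=0, keepdim=True).transpose(-1, -2)
    return M, z

def retrieve(M, z, Q, eps=1e-6):
    # Readout from the memory; eps guards the division against a tiny normalizer.
    return (sigma(Q) @ M) / (sigma(Q) @ z + eps)

Q = torch.randn(seg, d_k)
K = torch.randn(seg, d_k)
V = torch.randn(seg, d_v)

A_mem = retrieve(M, z, Q)                          # linear-kernel readout from memory
M, z = update_memory(M, z, K, V)
A_dot = F.softmax(Q @ K.T / d_k**0.5, dim=-1) @ V  # standard softmax attention on the segment

# The paper then mixes A_mem and A_dot with a learned gate; the key point is that
# A_mem only approximates what global softmax would give over the full history.
```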

ZackZikaiXiao commented 4 months ago

I'm starting to understand what you mean. If segmenting happens within the attention, we still need to concatenate all the segments when the attention function returns, and that concatenated matrix is as large as with SDPA, so it doesn't reduce VRAM. By segmenting in the training loop, only one set of compressive memory needs to be maintained. Thank you very much for your response!

jlamprou commented 4 months ago

@ZackZikaiXiao Glad to be of help. I suggest Yannic Kilcher's comprehensive explanation of the paper; it's worth your time.