mit-han-lab / streaming-llm

[ICLR 2024] Efficient Streaming Language Models with Attention Sinks
https://arxiv.org/abs/2309.17453
MIT License

question about re-computation #51

Closed ysanimals closed 9 months ago

ysanimals commented 9 months ago

Hello. After reading the paper and code, I have a question about the Sliding Window w/ Re-computation method: is there any code or paper about this method for reference? Thank you a lot.

Guangxuan-Xiao commented 9 months ago

Hi,

Thank you for your interest in our work!

The "Sliding Window with Recomputation" method is commonly used in language modeling evaluations to calculate perplexity given a certain number of context tokens. Huggingface provides a comprehensive tutorial on this topic. You can check it out here: https://huggingface.co/docs/transformers/perplexity.
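In case it helps, here is a rough sketch of that evaluation loop, adapted from the idea in that tutorial rather than copied from it (the model name and window size below are just placeholders):

```python
# Minimal sketch of sliding-window perplexity with recomputation.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # example model; swap in any causal LM
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).eval()

text = "..."  # your evaluation corpus
input_ids = tokenizer(text, return_tensors="pt").input_ids[0]

window = 512  # number of context tokens kept (the "L" in the paper's notation)
nlls = []
with torch.no_grad():
    for t in range(1, input_ids.size(0)):
        # Re-encode the last `window` tokens from scratch for every new position:
        # no KV cache is reused, so each step pays the full attention cost again.
        ctx = input_ids[max(0, t + 1 - window): t + 1].unsqueeze(0)
        logits = model(ctx).logits
        # Only the prediction for the newest token contributes to perplexity.
        nll = torch.nn.functional.cross_entropy(logits[0, -2], ctx[0, -1])
        nlls.append(nll)

ppl = torch.exp(torch.stack(nlls).mean())
print(f"perplexity: {ppl.item():.2f}")
```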

Best, Guangxuan

rituraj2847 commented 9 months ago

Hi Guangxuan. Could you please explain how sliding window with recomputation leads to O(TL^2) complexity? IMO, computing the keys and values for the first L-1 tokens should depend only on the weight matrices and the KV dimension. How does this computation take O(L^2) time per token? I don't see a need to compute attention scores for the initial L-1 tokens, which is what would cost O(L^2). Thank you.

Bhuvanesh09 commented 9 months ago

@rituraj2847 Your argument would hold if the transformer had only one layer, i.e. if we only needed a forward pass of each token through the K and V weight matrices, which is O(L). Unfortunately, the output of the first layer's attention block (obtained via (Q @ K.T) @ V) is the input embedding of the second layer. So every time the positional encodings of the tokens change, generating one new token takes O(L^2) ops, because self-attention has to be recomputed for all the tokens in the window. Since the sequence length is T, the overall complexity becomes O(TL^2). Only in the last layer can we avoid computing self-attention among all the initial tokens; there we only need the attention of the initial tokens with respect to the query vector of the last token. I had the same question as you and had to work through the transformer computation on paper to understand this.
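To make the scaling concrete, here is a tiny, purely illustrative op-count sketch (the function names and constants are made up for this example, not taken from the streaming-llm code):

```python
# Compare the dominant attention cost of sliding-window recomputation vs. KV-cache reuse.

def recompute_cost(T: int, L: int, n_layers: int) -> int:
    """Sliding window w/ recomputation: every step re-encodes the whole window,
    so every layer rebuilds a (window x window) attention matrix."""
    return sum(n_layers * min(t, L) ** 2 for t in range(1, T + 1))

def kv_cache_cost(T: int, L: int, n_layers: int) -> int:
    """KV-cache reuse: only the new token's query attends to the cached keys."""
    return sum(n_layers * min(t, L) for t in range(1, T + 1))

T, L, layers = 4096, 1024, 32
print(recompute_cost(T, L, layers))  # grows like T * L^2
print(kv_cache_cost(T, L, layers))   # grows like T * L
```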

rituraj2847 commented 9 months ago

Thank you. Understood :)
