hao-ai-lab / LookaheadDecoding

[ICML 2024] Break the Sequential Dependency of LLM Inference Using Lookahead Decoding
https://arxiv.org/abs/2402.02057
Apache License 2.0
1.11k stars, 66 forks

How to generate the n-grams - which to keep, which to discard? #14

Open bobqianic opened 10 months ago

bobqianic commented 10 months ago

This is actually not my question; llama.cpp wants to implement this but has run into some problems.

https://github.com/ggerganov/llama.cpp/pull/4207

Viol2000 commented 10 months ago

Hi, thanks for your interest!

Sorry for not presenting this clearly in our blog and code. We will refactor the code for better readability.

We maintain a lookahead branch (a 2D window) of size (W, N - 1). It is stored in the variable past_tokens and initialized here: https://github.com/hao-ai-lab/LookaheadDecoding/blob/main/lade/decoding.py#L223C5-L223C17 . Currently, we initialize the first level with tokens randomly selected from the prompt, using the function set_token. We then fill the remaining N - 2 levels during N - 2 warmup steps. See the discussion here: https://github.com/hao-ai-lab/LookaheadDecoding/issues/8 .

An example of the lookahead branch is in the figure below (N = 4 and W = 5)

Two points to note: 1) the first level (orange) has one token fewer than the other levels; 2) we need N - 2 warmup steps to gradually fill the lookahead branch. The second point is not very important: we currently fill the branch gradually, but for simplicity you can also initialize the whole lookahead branch with random tokens. It will still converge and give a speedup.
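A minimal sketch of the simpler option mentioned above: filling every level of the branch with tokens sampled from the prompt instead of running N - 2 warmup steps. The function name init_branch_randomly and the list-of-lists layout are hypothetical and are not the repo's past_tokens format; only the level sizes (W - 1 tokens for the first level, W for the others) follow the description above.

```python
import random
from typing import List, Sequence


def init_branch_randomly(prompt: Sequence[int], W: int, N: int,
                         seed: int = 0) -> List[List[int]]:
    """Return N - 1 levels: the oldest level has W - 1 tokens, the rest W each.

    Every token is drawn uniformly (with replacement) from the prompt tokens.
    """
    rng = random.Random(seed)
    sizes = [W - 1] + [W] * (N - 2)
    return [[rng.choice(prompt) for _ in range(size)] for size in sizes]


# Example with the figure's settings (N = 4, W = 5) and a toy prompt of token ids.
branch = init_branch_randomly([101, 7, 42, 9, 13], W=5, N=4)
```

This produces three levels of sizes 4, 5 and 5, so the first step can already collect N-grams; how quickly such a random start converges is not something this sketch measures.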

At each step, we generate one token from each window slot (the yellow tokens in the figure below). Then we do two things:

1) Collect N-grams (4-grams here), as in the figure above. 2) Drop the oldest level and use the newly generated tokens to keep the lookahead branch at a constant size: we drop the orange tokens and green token 1, and form a new lookahead branch from green tokens 2-5, all red tokens and all yellow tokens.
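The two per-step operations can be sketched in plain Python as follows. This illustrates the bookkeeping described above and is not the code in lade/decoding.py: lookahead_step is a hypothetical name, tokens are strings named after the figure's colors, and the sketch assumes N >= 3 and that N-grams are collected only from window slots that have a token at every level.

```python
from typing import List, Tuple

Token = str  # plain strings, purely for illustration


def lookahead_step(levels: List[List[Token]], new_tokens: List[Token]
                   ) -> Tuple[List[Tuple[Token, ...]], List[List[Token]]]:
    """One bookkeeping step for a branch of N - 1 levels and W window slots.

    levels[0] (oldest) holds W - 1 tokens for slots 1..W-1, levels[1:] hold
    W tokens each, and new_tokens holds the W tokens generated this step.
    """
    W = len(new_tokens)
    assert len(levels) >= 2, "sketch assumes N >= 3"
    assert len(levels[0]) == W - 1 and all(len(lv) == W for lv in levels[1:])

    # 1) Collect one N-gram per slot that has a token at every level.
    ngrams = []
    for j in range(1, W):
        column = [levels[0][j - 1]] + [lv[j] for lv in levels[1:]] + [new_tokens[j]]
        ngrams.append(tuple(column))

    # 2) Drop the oldest level plus slot 0 of the next level, then shift up.
    new_levels = [levels[1][1:]] + [list(lv) for lv in levels[2:]] + [list(new_tokens)]
    return ngrams, new_levels


levels = [
    ["o1", "o2", "o3", "o4"],            # orange: oldest level, W - 1 tokens
    ["g1", "g2", "g3", "g4", "g5"],      # green
    ["r2", "r3", "r4", "r5", "r6"],      # red: newest level
]
yellow = ["y3", "y4", "y5", "y6", "y7"]  # tokens generated at this step
ngrams, levels = lookahead_step(levels, yellow)
```

Here the collected 4-grams run from (o1, g2, r3, y4) to (o4, g5, r6, y7), and the new branch is green 2-5, red 2-6 and yellow 3-7.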

As for your question: the collected n-grams are stored in a hash map keyed by their starting token, whose value is the set of all following (n-1)-grams (variable token_map in our code). For each key, we cap the set at GUESS_SET_SIZE entries to keep the verification branch from growing too large. The logic that maintains this size is here: https://github.com/hao-ai-lab/LookaheadDecoding/blob/main/lade/decoding.py#L357C1-L385C1 . Currently, when the set grows larger than GUESS_SET_SIZE, we discard the n-gram that was added first. You can also explore other discard policies.
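A minimal sketch of such a pool, assuming integer token ids and first-in-first-out eviction per key. NGramPool, add and guesses are hypothetical names, not the repo's token_map API. Re-adding an n-gram that is already stored leaves its position unchanged here, which is only one of several reasonable choices.

```python
from collections import OrderedDict
from typing import Dict, List, Tuple


class NGramPool:
    """First token -> insertion-ordered set of the (n-1)-grams that followed it."""

    def __init__(self, guess_set_size: int):
        self.guess_set_size = guess_set_size
        self.pool: Dict[int, OrderedDict] = {}

    def add(self, ngram: Tuple[int, ...]) -> None:
        key, tail = ngram[0], tuple(ngram[1:])
        bucket = self.pool.setdefault(key, OrderedDict())
        if tail in bucket:
            return  # already stored: keep its original age
        bucket[tail] = None  # OrderedDict used as an ordered set
        if len(bucket) > self.guess_set_size:
            bucket.popitem(last=False)  # evict the (n-1)-gram added first

    def guesses(self, last_token: int) -> List[Tuple[int, ...]]:
        """Candidate continuations of last_token, oldest first."""
        return list(self.pool.get(last_token, OrderedDict()).keys())
```

With guess_set_size equal to G, as in the reply below, every stored candidate for the last token can go straight into the verification branch, so no further selection step is needed.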

Please feel free to ask if you have other questions.

ggerganov commented 10 months ago

For a given starting token and a pool of GUESS_SET_SIZE (n-1)-grams, how do you select the G n-grams to use for verification? AFAICT, GUESS_SET_SIZE can be large, say 60, and G is recommended to equal W, say 5. How do you choose 5 out of the 60?

Viol2000 commented 10 months ago

@ggerganov Hi, thanks for your interest.

Sorry for the confusing wording. In my implementation, GUESS_SET_SIZE is actually G in the blog. This way, the pool for a given starting token stays at a fixed size as I add new n-grams: when its size exceeds G, I drop the earliest n-gram in that pool. Making GUESS_SET_SIZE and G different is also fine and allows more selection policies, but I simply discard the earliest n-gram and set GUESS_SET_SIZE = G.

It is fine to set G equal to W at a small value like 5 or 10, or slightly larger than W. I think that is good enough to achieve speedups. With a large G like 60 you will face diminishing returns; it is not recommended on personal GPUs.

tdene commented 10 months ago

@Viol2000 I remain confused about the outputs of the lookahead branch.

You say

And we form a new lookahead branch with green tokens with number 2-5, all red tokens and all yellow tokens.

What happens to the output tokens that correspond to input tokens of green 1-5 and orange 1-4? If they are not used in the next iteration of the lookahead algorithm, what are they used for? Are they used as speculation drafts alongside the cache or are they just not computed?

Viol2000 commented 10 months ago

@tdene Thanks for your interest! You raised some excellent points!

These tokens' outputs are discarded. However, I do not think computing them is wasted: although their outputs are discarded, the tokens themselves are used to build the yellow tokens with stronger local information. For example, yellow 7 attends to orange 1-4, green 5 and red 6 to produce its output.

Another option is to put these tokens in the KV cache, so we can remove the rows orange 1-4 and green 1-5 and save FLOPs. It seems llama.cpp's implementation of lookahead decoding uses this method (correct me if I am wrong). However, if you put them in the KV cache, they miss the information from newly arrived tokens. For example, if orange 1 is in the KV cache, its entry was built 2 steps earlier and has not changed since. At that time blue 0 had not been obtained, so the KV-cache entry of orange 1 carries information from an incorrect token rather than from blue 0. I think this reduces the local information and lowers the acceptance rate of your guess tokens.

ggerganov commented 10 months ago

Another way to do this is to put these tokens in kv-cache. So we can remove row orange 1-4 and green 1-5 and save flops. Moreover, it seems llama.cpp's implementation of lookahead decoding uses this method(correct me if I was wrong).

That is not the case in llama.cpp. For the case in the figure, we submit a full batch containing:

None of the lookahead tokens are stored in the KV cache so they will "see" the blue token from the current iteration.

Viol2000 commented 10 months ago

Hi @ggerganov, thank you for the clarification! I noticed the following diagram, and it appears to omit several rows, so my initial assumption was that you had an innovation in KV-caching past tokens. Nevertheless, I appreciate your substantial effort in implementing Lookahead Decoding in llama.cpp.

        // Example for W = 5, N = 4, G = 2:
        // (I = input, L = lookahead, V = verification)
        //
        // Batch:  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20
        // T:        -2 -2 -2 -2 -1 -1 -1 -1 -1  0  0  0  0  0  0
        // Info:   I  L  L  L  L  L  L  L  L  L  L  L  L  L  L  V  V  V  V  V  V
        // Pos:    0  1  2  3  4  1  2  3  4  5  2  3  4  5  6  1  2  3  1  2  3   (+ n_past)
        // Logits: 1  0  0  0  0  0  0  0  0  0  1  1  1  1  1  1  1  1  1  1  1
        // ---------------------------------------------------------------------
        // Seq:    0
        //         1              1              1
        //         2  2              2              2
        //         3  3  3              3              3
        //         4  4  4  4              4              4
        //         5  5  5  5  5              5              5
        //         6                                            6  6  6
        //         7                                                     7  7  7
        // ---------------------------------------------------------------------
        //                                       |  |  |  |  |  |  |  |  |  |  |
        //                                       V  V  V  V  V  |  |  |  |  |  |
        //                                         j_tokens     |  |  |  |  |  |
        //                                                      V  V  V  V  V  V
        //                                                             id
ggerganov commented 10 months ago

This diagram represents the input batch containing 21 tokens. Later in the code, based on the sequence ids of the tokens in the batch, we construct the attention mask which is 21 x 21.
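The description can be made concrete by transcribing the diagram's positions and sequence ids and deriving the 21 x 21 mask from them. This is a minimal sketch under an assumed, simplified rule (query i may attend key j when they share at least one sequence id and pos[j] <= pos[i]; the KV cache is ignored). It is not llama.cpp's code, and build_mask and visible are hypothetical helpers.

```python
from typing import List, Set

# Batch from the diagram above (W = 5, N = 4, G = 2), positions relative to n_past.
POS: List[int] = [0, 1, 2, 3, 4, 1, 2, 3, 4, 5, 2, 3, 4, 5, 6, 1, 2, 3, 1, 2, 3]
SEQ: List[Set[int]] = (
    [set(range(8))]                             # 0: input token, in every sequence
    + [{2, 3, 4, 5}, {3, 4, 5}, {4, 5}, {5}]    # 1-4: oldest lookahead level (T = -2)
    + [{s} for s in range(1, 6)]                # 5-9: lookahead level T = -1
    + [{s} for s in range(1, 6)]                # 10-14: lookahead level T = 0
    + [{6}] * 3 + [{7}] * 3                     # 15-20: two verification n-grams
)


def build_mask(pos: List[int], seq: List[Set[int]]) -> List[List[bool]]:
    """mask[i][j] is True when query token i may attend to key token j.

    Simplified rule: the tokens share a sequence id and j is not later than i.
    """
    n = len(pos)
    return [[bool(seq[i] & seq[j]) and pos[j] <= pos[i] for j in range(n)]
            for i in range(n)]


def visible(mask: List[List[bool]], i: int) -> List[int]:
    """Batch indices that query i attends to."""
    return [j for j, ok in enumerate(mask[i]) if ok]


mask = build_mask(POS, SEQ)
```

Under this rule, row 13 sees batch entries 0, 1, 2, 3, 8 and 13 (blue 0, orange 1-3, green 4, red 5), matching the red 5 line discussed below, and row 14 sees blue 0, orange 1-4, green 5 and red 6, matching the yellow 7 example earlier in the thread.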

tdene commented 10 months ago

And we form a new lookahead branch with green tokens with number 2-5, all red tokens and all yellow tokens.

these tokens themselves are used to build yellow tokens with stronger local information. For example, yellow 7 takes orange1-4, green 5 and red 6 in attention and give output.

So the output of the red 5 line in the diagram (which pays attention to blue 0, orange 1, orange 2, orange 3, green 4, red 5), which you name "yellow 6", becomes the new red 5 in the next time-step's calculation?

And the output of the green 3 line in the diagram (which pays attention to blue 0, orange 1, orange 2, green 3), which you name "red 4", becomes the new green 3 in the next time-step's calculation?

And then the output of the orange 2 line in the diagram (which pays attention to blue 0, orange 1, orange 2), which you name "green 3", becomes the next orange 2 in the next time-step's calculation?

I was confused because, from that first quote, it sounded as though you use the same "green tokens with number 2-5, all red tokens and all yellow tokens" shown in the diagrams as inputs. But given that "these tokens themselves are used to build yellow tokens with stronger local information", I now understand that you don't use green tokens 2-5 from the diagram; you use the outputs of their lines?

hsm1997 commented 10 months ago

@tdene

Looking into the code, I think:

So the output of the red 5 line in the diagram (which pays attention to blue 0, orange 1, orange 2, orange 3, green 4, red 5), which you name "yellow 6", becomes the new red 5 in the next time-step's calculation?

yes

And the output of the green 3 line in the diagram (which pays attention to blue 0, orange 1, orange 2, green 3), which you name "red 4", becomes the new green 3 in the next time-step's calculation?

The red tokens in the diagram are generated from green, not in the current decode iteration (when green and red already exist) but in the previous iteration, when red did not exist yet. If red does not exist, decode with (input_ids, orange, green) to get red; if red already exists, decode with (input_ids, orange, green, red) to get yellow, then use green to update orange, red to update green, and yellow to update red. The outputs of green and orange in the current decode iteration are discarded, as the author mentioned earlier.

learning-chip commented 8 months ago

These tokens' outputs are discarded. However, I do not think it is a waste to compute them. Although these tokens' outputs are discarded, these tokens themselves are used to build yellow tokens with stronger local information. For example, yellow 7 takes orange1-4, green 5 and red 6 in attention and give output.

Hi @Viol2000 -- regarding this comment, if the green and orange rows of output are not used, why not just trim them out in the attention mask? This won't have any effect on the output yellow-3,4,5,6,7 tokens, and also won't affect the collected 4-grams.

[Figure: lookahead_trim, the proposed attention mask with the green and orange output rows trimmed]

Are the green and orange outputs somehow used to refine past states (y(t-2) and y(t-3))? If not, the above mask should work fine: information still flows from the inputs {orange 1-4, green 5, red 6} to the output yellow 7 through the remaining mask.

learning-chip commented 8 months ago

why not just trim them out in the attention mask?

I checked the related code:

https://github.com/hao-ai-lab/LookaheadDecoding/blob/b756db313419298d292a927c6dda950020ec1073/lade/models/llama.py#L445-L451

https://github.com/hao-ai-lab/LookaheadDecoding/blob/b756db313419298d292a927c6dda950020ec1073/lade/decoding.py#L274-L275

In the higher-level call, the full outputs.logits is never used directly; only its segments out_logits, inp_logits and guess_logits are used.

Viol2000 commented 8 months ago

These tokens' outputs are discarded. However, I do not think it is a waste to compute them. Although these tokens' outputs are discarded, these tokens themselves are used to build yellow tokens with stronger local information. For example, yellow 7 takes orange1-4, green 5 and red 6 in attention and give output.

Hi @Viol2000 -- regarding this comment, if the green and orange rows of output are not used, why not just trim them out in the attention mask? This won't have any effect on the output yellow-3,4,5,6,7 tokens, and also won't affect the collected 4-grams.

[Figure: lookahead_trim, the proposed attention mask with the green and orange output rows trimmed]

Are the green and orange outputs somehow used to refine past states (y(t-2) and y(t-3))? If not, then the above mask should totally work fine. The information still flows from input {orange1-4, green 5, red 6} to the output yellow-7, using such remaining mask.

@learning-chip Sorry for the late reply; I did not see your message. Even though the green and orange output rows are not used, these tokens are useful for generating yellow tokens 3-7. If the current step is t, the red tokens come from step t-1, the green tokens from step t-2, and the orange tokens from step t-3. If you do not input these tokens, you are effectively using the KV cache of the green and orange tokens, and the information of the current input (deep blue 0) cannot reach them, because their KV-cache entries are static. So they are out of date. In short, if you do not input the green and orange rows, those tokens cannot 'see' deep blue 0; only the red tokens can. If you do input them, they pick up the information of deep blue 0 in shallow layers and thus affect the output of the red tokens in deep layers.

I have some initial implementations showing that without the green and orange rows as inputs, the generated n-grams are less likely to be accepted. I believe there are further trade-offs to explore: lower computational cost with a lower acceptance ratio, or higher computational cost with a higher acceptance ratio.

Viol2000 commented 8 months ago

why not just trim them out in the attention mask?

I checked the related code:

https://github.com/hao-ai-lab/LookaheadDecoding/blob/b756db313419298d292a927c6dda950020ec1073/lade/models/llama.py#L445-L451

https://github.com/hao-ai-lab/LookaheadDecoding/blob/b756db313419298d292a927c6dda950020ec1073/lade/decoding.py#L274-L275

In higher-level call, the full outputs.logits are never directly used. Only its segments out_logits, inp_logits, guess_logits are used.

  • out_logits is at position prefill_size - 1 -- is the one "must-be-correct" token?
  • guess_logits is at position -lguess: -- is the verification result? (still kept in the trimmed mask)
  • inp_logits is at position -len(past_tokens[fill_level])-lguess:-lguess -- correspond to the yellow-3,4,5,6,7 tokens in your figure? (still kept in the trimmed mask) (also what does inp stand for?)

My naming is bad here; I will keep refactoring the code in the following weeks.

learning-chip commented 8 months ago

I have some initial implementations that show that without green and orange rows as inputs, the generated n-grams have a lower quality to be accepted.

Hi @Viol2000 thanks for the explanation. For this part I totally understand. By "trim the mask" I mean shortening the y-axis (output dimension, or q dimension), not shortening the x-axis (input dimension, or kv dimension). The output red tokens can still see the full inputs.

Viol2000 commented 8 months ago

@learning-chip I think every input token corresponds to one row (query), and each column corresponds to one input token or one KV-cache entry (key). Is that correct? So the green and orange tokens must be either inputs or in the KV cache. If they are in the KV cache, the situation I explained above happens. If they are inputs, we need to have them as rows. Is my understanding correct?

learning-chip commented 8 months ago

If they are input, we need to have these as rows.

This is only true if you restrict yourself to the LlamaAttention.forward() API, which is hard-coded as self-attention. In that case the Q and K tokens have to be the same: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/models/llama/modeling_llama.py#L386-L388

If you use the FlashAttention or PyTorch scaled_dot_product_attention API directly, the query and key lengths can differ. This is more like cross-attention (not between two entirely different sequences, but between a sequence and a segment of it); ref: https://github.com/Dao-AILab/flash-attention/blob/v2.1.2/flash_attn/modules/mha.py#L638-L642

Viol2000 commented 8 months ago

I know what you mean, but I do not think it is an implementation problem. The problem is that if you do not input the orange and green tokens, they will not be included in hidden_states, so they would have to be stored in the KV cache.

learning-chip commented 8 months ago

The problem is if you do not input orange and green ones, they will not be included in the hidden_states.

Here, hidden_states for Q is shortened to exclude the orange and green positions, while hidden_states for K and V stays unchanged (keeping the orange and green positions). The code should use two different variable names, such as hidden_states_q and hidden_states_kv. Using the same hidden_states for Q, K and V is a restriction only if you use the original LlamaAttention.forward().

Viol2000 commented 8 months ago

So you mean we still input all orange and green tokens?

Viol2000 commented 8 months ago

If so, the orange and green tokens need to go through MLP but not participate in attention computation. Is my understanding correct?

learning-chip commented 8 months ago

So you mean we still input all orange and green tokens?

Yes, so K and V still have the complete input information. Just make Q shorter. Rewriting the LlamaAttention.forward() method will allow this.

Viol2000 commented 8 months ago

But I doubt they will keep the input information if they are not fed into attention as queries. As far as I understand, attention outputs correspond to queries. If they are only keys and not queries, they will somehow 'skip' attention and not include the information we want. It is an inspiring idea, but it seems strange.

learning-chip commented 8 months ago

If so, the orange and green tokens need to go through MLP but not participate in attention computation. Is my understanding correct?

Right. The MLP is applied independently to each token (no cross-token information flow) and stays as it is. This just saves FLOPs in the attention computation.

learning-chip commented 8 months ago

If not put them as queries and just as keys, they will somehow 'skip' the attention and not include the information we want.

softmax(Q*K') is a pairwise computation: if you don't need some output rows, you don't need to include them in Q.
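Because each output row depends only on its own query row (plus all keys and values), dropping unused query rows leaves the remaining rows unchanged. A pure-Python sketch, not tied to any library API, of attention that accepts fewer queries than keys; it uses Q = K = V = X and toy numbers for brevity, and attention is a hypothetical helper name.

```python
import math
from typing import List, Optional

Matrix = List[List[float]]


def attention(q: Matrix, k: Matrix, v: Matrix,
              mask: Optional[List[List[bool]]] = None) -> Matrix:
    """Row-wise softmax(q k^T / sqrt(d)) v.

    len(q) may be smaller than len(k). mask[i][j] == False hides key j
    from query i.
    """
    d = len(q[0])
    out: Matrix = []
    for i, qi in enumerate(q):
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        if mask is not None:
            scores = [s if mask[i][j] else -math.inf for j, s in enumerate(scores)]
        m = max(scores)  # subtract the max for numerical stability
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(wj * vj[c] for wj, vj in zip(w, v)) / z
                    for c in range(len(v[0]))])
    return out


# Toy data: 5 tokens with 2-dim embeddings and a causal mask.
X = [[0.1, 0.2], [0.3, -0.1], [0.5, 0.4], [-0.2, 0.6], [0.7, -0.3]]
causal = [[j <= i for j in range(len(X))] for i in range(len(X))]
keep = [0, 3, 4]  # the output rows we actually need

full = attention(X, X, X, causal)                                           # 5 query rows
trimmed = attention([X[i] for i in keep], X, X, [causal[i] for i in keep])  # 3 query rows
```

Each row of trimmed equals the corresponding row of full, while K and V still cover all five tokens; within a single layer, this is the "shorter Q, unchanged K/V" idea.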

learning-chip commented 8 months ago

@Viol2000

The problem is if you do not input orange and green ones, they will not be included in the hidden_states.

I got your point and did some more experiments.

Shortening Q for the last decoding layer is perfectly fine. It has no impact on algorithm correctness; it just saves FLOPs for self.q_proj(hidden_states) * x and Q*K'. Such FLOP savings can benefit batching, which requires more hardware FLOPs (or a higher compute-to-memory ratio).

Shortening Q for other decoding layers will affect the inputs to their next layers. In the end, this affects the lookahead guesses and thus the speed-up ratio. Testing with TinyLlama/TinyLlama-1.1B-Chat-v1.0 with increasing skip_layers (Q shortened for the last skip_layers layers):

# experiment results: (skip_layers, compression_ratio)
# 0: 1.84 (standard lookahead implemented by this repo)
# 1: 1.84 (no effect)
# 2: 1.77
# 3: 1.86 (better)
# 4: 1.68
# 5: 1.78
# 6: 1.71
# 7: 1.68
# 8: 1.78
# 9: 1.66
# 10: 1.59
# 11: 1.66
# 12: 1.6
# ...
# 20: 1.48
# 21: 1.43
# 22 (all): 1.34

So the speed-up drops roughly (but not strictly) monotonically as more layers have a shortened query/output dimension.

To try the above test, modify the code in LlamaDecoderLayer.forward as follows:

# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
    ...
)  # unchanged
used_tokens = len(past_tokens[fill_level]) + lguess  # obtained from jforward_multilevel()
hidden_states[:, 1:-used_tokens, :] = 0.0  # as if those output positions were not computed
hidden_states = residual + hidden_states  # just pass the inputs through for those positions

# Fully Connected
# (unchanged)

Then branch to this modified decoder.forward() only for the last few decoder layers: https://github.com/hao-ai-lab/LookaheadDecoding/blob/b756db313419298d292a927c6dda950020ec1073/lade/models/llama.py#L234-L236

This modification does not actually save FLOPs; it only shows what the guess accuracy would be if self.self_attn were replaced by "cross attention" with a shortened query and a trimmed mask (https://github.com/hao-ai-lab/LookaheadDecoding/issues/14#issuecomment-1903633633).

Further, after obtaining the combined attention mask from j_prepare_decoder_attention_mask:

https://github.com/hao-ai-lab/LookaheadDecoding/blob/b756db313419298d292a927c6dda950020ec1073/lade/models/llama.py#L201-L203

you can also set the unused mask rows to any value:

used_tokens = len(past_tokens[fill_level]) + lguess # obtained from jforward_multilevel()
maskout_value = torch.finfo(attention_mask.dtype).min  # any value, doesn't matter
attention_mask[:,:,1:-used_tokens,:] = maskout_value

With the previous modification to LlamaDecoderLayer.forward, this modification to mask has no further impact on computation results. At least for the last decoder layer, it is safe to take out those mask rows & output positions.
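Why trimming is harmless in the last layer but not in earlier ones follows from how residual layers stack, and a toy model can show it. The sketch below is a hypothetical two-layer, attention-only residual stack in pure Python (no MLP, no projections, Q = K = V); names such as toy_model are illustrative and not part of the repo. Rows 1 and 2 stand in for the orange/green positions whose outputs are never read. It shows only the direction of the effect, not the measured ratios above.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def causal_attention_rows(x: Matrix, rows: Sequence[int]) -> Matrix:
    """Causal self-attention with Q = K = V = x, computed only for `rows`."""
    d = len(x[0])
    out: Matrix = []
    for i in rows:
        keys = x[: i + 1]  # causal: query i sees tokens 0..i
        scores = [sum(a * b for a, b in zip(x[i], kj)) / math.sqrt(d) for kj in keys]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(wj * kj[c] for wj, kj in zip(w, keys)) / z for c in range(d)])
    return out


def block(x: Matrix, rows: Sequence[int]) -> Matrix:
    """Residual attention block. Rows outside `rows` skip attention and pass x
    through unchanged, like zeroing hidden_states before the residual add."""
    y = [list(r) for r in x]
    for i, a in zip(rows, causal_attention_rows(x, rows)):
        y[i] = [xi + ai for xi, ai in zip(x[i], a)]
    return y


def toy_model(x: Matrix, used: Sequence[int], trimmed_layers: int,
              n_layers: int = 2) -> Matrix:
    """Run n_layers blocks; only the last `trimmed_layers` restrict Q to `used`."""
    for layer in range(n_layers):
        trimmed = layer >= n_layers - trimmed_layers
        x = block(x, used if trimmed else range(len(x)))
    return [x[i] for i in used]


X = [[0.1, 0.2], [0.3, -0.1], [0.5, 0.4], [-0.2, 0.6], [0.7, -0.3]]
used = [0, 3, 4]  # rows whose outputs are read
reference = toy_model(X, used, trimmed_layers=0)
last_only = toy_model(X, used, trimmed_layers=1)   # identical to reference
all_layers = toy_model(X, used, trimmed_layers=2)  # rows 3 and 4 change
```

Trimming the last layer only removes work whose results are never read. Trimming the first layer leaves rows 1 and 2 without the current context, so the second layer's keys and values for those rows change, and with them the outputs of rows 3 and 4.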

Viol2000 commented 8 months ago

@learning-chip Thanks for the valuable discussion and experiments! I think mask trimming does show potential for saving FLOPs, and when only a few layers are pruned, the step compression ratio does not drop. But I doubt that reducing these FLOPs will translate into speedups, since we cannot prune too many attention layers and attention accounts for only part of the FLOPs of the entire model. As the compression ratio does not perfectly reflect overall latency, there may be a trade-off between compression ratio and FLOPs. I think it is worth further exploration.
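A back-of-the-envelope sketch of the point that attention is only part of the per-layer FLOPs. It counts only decoder-layer matmuls (2 FLOPs per multiply-add) and ignores norms, embeddings and the LM head. The dimensions are assumed TinyLlama-1.1B-like values (hidden 2048, 256-dim K/V from 4 KV heads of size 64, MLP 5632, 22 layers) plus an assumed 512-token context; check the model's config before relying on them. The helper names are hypothetical.

```python
from typing import Dict


def per_token_layer_flops(d: int, d_kv: int, d_ff: int, ctx: int) -> Dict[str, int]:
    """Approximate matmul FLOPs for one token in one decoder layer."""
    return {
        "q_o_proj": 2 * 2 * d * d,     # q_proj and o_proj
        "k_v_proj": 2 * 2 * d * d_kv,  # k_proj and v_proj (d_kv < d with GQA)
        "attn": 2 * 2 * ctx * d,       # Q K^T, then softmax(.) V over ctx keys
        "mlp": 3 * 2 * d * d_ff,       # gate, up and down projections
    }


def saved_fraction(n_tokens: int, n_trimmed: int, n_layers: int, trimmed_layers: int,
                   d: int = 2048, d_kv: int = 256, d_ff: int = 5632,
                   ctx: int = 512) -> float:
    """Share of decoder-layer FLOPs saved by dropping n_trimmed query rows in
    `trimmed_layers` layers; trimmed rows still pay for k/v_proj and the MLP."""
    f = per_token_layer_flops(d, d_kv, d_ff, ctx)
    total = n_tokens * n_layers * sum(f.values())
    saved = n_trimmed * trimmed_layers * (f["q_o_proj"] + f["attn"])
    return saved / total


# 21-token batch as in the llama.cpp diagram, 9 unused lookahead rows, 22 layers.
fractions = {k: saved_fraction(21, 9, 22, k) for k in (1, 3, 22)}
```

Under these assumptions, trimming one layer saves well under 0.5% of decoder-layer FLOPs, and even trimming all 22 layers saves only about 10%, because the MLP and K/V projections still run for every token.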