hao-ai-lab / LookaheadDecoding

[ICML 2024] Break the Sequential Dependency of LLM Inference Using Lookahead Decoding
https://arxiv.org/abs/2402.02057
Apache License 2.0

Questions on combined attention mask structure for Jacobi iteration #44

Open learning-chip opened 8 months ago

learning-chip commented 8 months ago

I have some questions about the structure of custom mask for lookahead and verify branches as described in the blog.

Related code

The combined_attention_mask created by j_prepare_decoder_attention_mask(): https://github.com/hao-ai-lab/LookaheadDecoding/blob/b756db313419298d292a927c6dda950020ec1073/lade/models/llama.py#L201-L203

Such attention mask is then sent to the decode layer, in order to compute all branches in a single model.forward call:

https://github.com/hao-ai-lab/LookaheadDecoding/blob/b756db313419298d292a927c6dda950020ec1073/lade/models/llama.py#L234-L236

1. Token dependency in lookahead branches

For the upper-left corner (blue-0 and orange-1,2,3,4):

[image: crop_jacobi — upper-left block of the combined mask (blue-0 and orange-1,2,3,4)]

I can understand that it is the Jacobi iteration (screenshot from this paper):

[image: jacobi_formula — $y_i^{(t)} = \arg\max_y\, p\!\left(y \mid y_{1:i-1}^{(t-1)},\, x\right)$, all positions $i$ updated in parallel]

where each token in the current guess is updated concurrently, based on the current values. The input is a 5-token guess (in this figure), and the output is the updated 5-token guess. The input-output dependency is a normal triangular (causal) mask.
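(For reference, here is my own toy sketch of this plain greedy Jacobi update, assuming a HF-style causal LM; `model`, `W`, and the stopping rule are illustrative, not the LADE code.)

```python
# Toy sketch of one greedy Jacobi iteration: all W guess tokens are updated
# concurrently from the *previous* guess (an ordinary causal mask inside the model).
import torch

@torch.no_grad()
def jacobi_step(model, prompt_ids, guess_ids):
    W = guess_ids.shape[-1]
    ids = torch.cat([prompt_ids, guess_ids], dim=-1)   # [1, prompt_len + W]
    logits = model(ids).logits                         # [1, prompt_len + W, vocab]
    # logits at position p predict the token at position p + 1, so the W guess
    # positions are predicted by the logits at positions prompt_len-1 .. prompt_len+W-2
    return logits[:, -W - 1:-1, :].argmax(dim=-1)      # updated guess, shape [1, W]

# Iterate to the greedy fixed point:
#   guess = torch.randint(0, vocab_size, (1, W))
#   while True:
#       new_guess = jacobi_step(model, prompt_ids, guess)
#       if torch.equal(new_guess, guess):
#           break
#       guess = new_guess
```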

What I don't fully understand is the green and red branches:

[image: jacobi_history — the green and red rows of the lookahead-branch mask]

So why does each output token depend on cross-iteration tokens rather than in-iteration tokens? For example, why does green-5 attend to {orange-4,3,2,1, blue-0} instead of the other green tokens at those positions?

This is only about the model.forward computation part, not the N-gram cache management logic (which lives at the greedy_search() level, above model.forward()). So this part should fall within the Jacobi decoding framework -- which shouldn't have cross-iteration dependencies, right? (The step-t state is computed from step t-1, not from t-2 or earlier.)

2. Past tokens (KV cache) and ways to call fused operators

The blog's mask omits the past token history. The actual, full attention mask sent to LlamaAttention.forward() should look like this?

[image: left_padded_mask — full mask including optional left padding, cached past tokens, and the new lookahead/verify queries]

The middle yellow block has KV cache available (past_key_values). The left block is the optional left padding used in static-batching settings to pad to the longest prompt length. The right block (the one shown in the blog) handles the newly constructed queries (all branches concatenated), for which no KV cache is available.

So what would be the proper way to call a fused FlashAttention operator for such a mask pattern? The Triton version of FlashAttention supports a custom attention bias -- setting the bias to -inf according to the mask pattern should give the desired masking effect, right? Has anyone checked the performance gain?
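Something like the following is what I have in mind (a sketch only -- the flash_attn_func import path and signature for the Triton version are an assumption on my side, so the call is left commented out):

```python
# Sketch: turn the boolean combined mask into an additive bias (-inf = masked)
# for a bias-capable Triton flash-attention kernel.
import torch
# from flash_attn.flash_attn_triton import flash_attn_func   # assumed entry point, not verified

def mask_to_bias(allowed: torch.Tensor, dtype=torch.float16) -> torch.Tensor:
    """allowed: [bsz, tgt_len, src_len] bool (True = may attend)."""
    bias = torch.zeros(allowed.shape, dtype=dtype, device=allowed.device)
    bias.masked_fill_(~allowed, float("-inf"))
    return bias.unsqueeze(1)                                  # broadcast over heads

# q, k, v: [bsz, seqlen, nheads, headdim]
# out = flash_attn_func(q, k, v, bias=mask_to_bias(combined_allowed), causal=False)
```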

learning-chip commented 8 months ago

For the simplest Jacobi decoding mask (like below), xformers supports such a shape via LowerTriangularFromBottomRightMask:

[image: jacobi_only — plain Jacobi decoding mask, lower triangular aligned to the bottom right]

(BTW, the BlockDiagonalMask can also help with static batching without padding)
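A minimal sketch of what I mean (shapes follow memory_efficient_attention's [batch, seqlen, heads, head_dim] convention; the sizes are arbitrary):

```python
# Sketch: plain Jacobi-style decoding mask via xformers, i.e. the 5 new guess
# queries attend causally to the *end* of a longer key sequence (past + guesses).
import torch
from xformers.ops import memory_efficient_attention
from xformers.ops.fmha.attn_bias import LowerTriangularFromBottomRightMask

B, H, D = 1, 32, 128
q = torch.randn(B, 5, H, D, device="cuda", dtype=torch.float16)     # 5 guess tokens
k = torch.randn(B, 133, H, D, device="cuda", dtype=torch.float16)   # 128 past + 5 new
v = torch.randn(B, 133, H, D, device="cuda", dtype=torch.float16)

out = memory_efficient_attention(q, k, v, attn_bias=LowerTriangularFromBottomRightMask())
```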

learning-chip commented 8 months ago

A further question, crucial for inference deployment, is how the merged lookahead computation can make use of a paged/blocked KV cache. It needs a paged-attention kernel that allows seqlen_q > 1, in order to take the concatenated inputs (lookahead + verify) in one kernel call.

However, vLLM's paged attention kernel seems limited to seqlen_q = 1 (ref https://github.com/Dao-AILab/flash-attention/issues/660#issuecomment-1803113706). In the test it is only compared to single-query attention.

This new flash_attn_with_blocked_kvcache kernel (https://github.com/vllm-project/vllm/issues/1880#issuecomment-1882962731) might work?
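For reference, the kind of call I am after looks roughly like this (flash_attn_with_kvcache from the flash-attn package; whether and how a given release supports a paged block table is something I have not verified):

```python
# Sketch: one kernel call with seqlen_q > 1 over a pre-allocated KV cache.
# Note: causal=True here means lower-triangular aligned to the bottom-right,
# which matches the plain Jacobi mask but NOT the full lookahead+verify mask.
import torch
from flash_attn import flash_attn_with_kvcache

B, H, D, max_len = 1, 32, 128, 2048
k_cache = torch.zeros(B, max_len, H, D, device="cuda", dtype=torch.float16)
v_cache = torch.zeros(B, max_len, H, D, device="cuda", dtype=torch.float16)
cache_seqlens = torch.full((B,), 128, dtype=torch.int32, device="cuda")  # tokens already cached

q = torch.randn(B, 5, H, D, device="cuda", dtype=torch.float16)          # 5 concatenated query tokens
k_new = torch.randn(B, 5, H, D, device="cuda", dtype=torch.float16)
v_new = torch.randn(B, 5, H, D, device="cuda", dtype=torch.float16)

out = flash_attn_with_kvcache(q, k_cache, v_cache, k=k_new, v=v_new,
                              cache_seqlens=cache_seqlens, causal=True)
```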

shermansiu commented 8 months ago

Once again, not one of the authors, but here's my answer:

  1. So why does each output token depend more on cross-iteration tokens, not in-iteration tokens? Because that's what would happen if you were to do Jacobi decoding sequentially instead of in parallel. Blue+Orange = Regular decoding, Blue+Orange+Green = Jacobi at time step + 1, Blue+Orange+Red = Jacobi at time step + 2.

I think the part you need to understand is that $N-1$ Jacobi decoding steps are done in parallel in the lookahead branch.

  1. You can't pass a custom mask to the Flash-Attention kernel without modifying its code, because the flash attention code creates the causal mask on the fly (i.e. it's created in-code, not passed as a parameter).

  2. LowerTriangularFromBottomRightMask is a convenience function for that specific purpose. Actually, it would be simpler to just generate the mask specifically for lookahead decoding instead.

  3. That's fine... Everything below the first row is discarded in the next iteration and never makes it into the KV cache anyway.

learning-chip commented 8 months ago

Thanks for the explanation, but I think some points are not correct...

  1. Because that's what would happen if you were to do Jacobi decoding sequentially instead of in parallel. Blue+Orange = Regular decoding

No, Blue-0 + Orange-1,2,3,4 can be viewed either as one Jacobi iteration updating the whole 5-token guess in parallel, or as a parallel verification of that 5-token guess against greedy autoregressive decoding.

The two views are equivalent. If the 5 tokens are a perfect guess, it also means that the Jacobi iteration has converged to the true solution.

Such a computation is not regular autoregressive decoding. The token state y(t) is computed concurrently (along the sequence dimension) from the previous state y(t-1), not sequentially token by token.

[image: mask — the combined lookahead mask being discussed]

Still, my question is -- why must green-5 depend on {green-5, orange-4,3,2,1, blue-0}, and not on {green-5,4,3,2,1, blue-0}, or {green-5,4,3,2, orange-1, blue-0}, or other combinations of green/orange states (as long as the positions span 5,4,3,2,1,0)?

2. You can't pass a custom mask to the Flash-Attention kernel without modifying its code

I know -- that's why I mentioned the Triton version, which takes an additive mask... The CUDA version doesn't support a mask input.

3. LowerTriangularFromBottomRightMask is a convenience function for that specific purpose. Actually, it would be simpler to just generate the mask specifically for lookahead decoding instead.

Yes, and my question is exactly that -- given the combined_attention_mask generated for lookahead, how do I find a fused operator that supports such a mask pattern? The current LADE code does not call a fused attention kernel, and this is the performance bottleneck.

4. That's fine... Everything below the first row is discarded in the next iteration and never makes it into the KV-cache anyway

Yes, thanks for the confirmation. I think this is why lookahead decoding is able to use more FLOPs than normal decoding, as the "query" length is longer and the matmul shape is more "square".

shermansiu commented 8 months ago
  1. (a) Actually, the lookahead branch and verification branches are labeled in the diagram you posted. The part you highlighted is part of the lookahead branch.

With iterative (sequential) lookahead decoding, we do the steps one at a time -- blue+orange, then blue+orange+green, then blue+orange+red -- one forward pass per step.

Refer to Figure 4 from the blog for details.

With lookahead decoding, you do all of those steps in a single forward pass.

That's what I meant by in parallel.

The following answers from the lead author may also help. https://github.com/hao-ai-lab/LookaheadDecoding/issues/28#issuecomment-1847269260 https://github.com/hao-ai-lab/LookaheadDecoding/issues/14#issuecomment-1826198518

I never claimed that it was like regular autoregressive decoding. The causal mask used in autoregressive decoding is quite different from that for lookahead decoding. Both Jacobi decoding (with guess tokens) and lookahead decoding are different from regular autoregressive decoding.


Edit: Ah, I noticed that I claimed earlier that blue+orange was part of the regular decoding, as opposed to the first Jacobi step. Whoops. Only blue is part of the regular decoding.


(b):

  • Blue + Orange = Step 1: parallel decode
  • Blue + Orange + Green = Step 2: parallel decode
  • Blue + Orange + Red = Step 3: parallel decode

You can check that the tokens being attended to match those from sequential lookahead decoding without verification.

I know -- that's why I mentioned the Triton version, which takes an additive mask... The CUDA version doesn't support a mask input.

  1. Yeah, that should work. Recently, the LADE code was updated for transformers 4.36, where Llama now uses LlamaSdpaAttention, but I haven't looked at that PR too closely. It should call torch.nn.functional.scaled_dot_product_attention, which is an implementation of Flash Attention. It's FA1 in PyTorch version 2.0-2.1 and will be FA2 in the upcoming 2.2 release (https://github.com/pytorch/pytorch/pull/105602#issuecomment-1755453126)
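For what it's worth, a minimal sketch of pushing a combined mask through SDPA (my illustration; note that, as far as I know, a non-None attn_mask prevents dispatch to the flash backend, so you would typically get the memory-efficient or math kernel instead):

```python
# Sketch: SDPA with an additive combined mask (0 = keep, -inf = masked).
import torch
import torch.nn.functional as F

B, H, Q, K, D = 1, 32, 16, 144, 128          # Q > 1: concatenated lookahead + verify queries
q = torch.randn(B, H, Q, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, K, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, K, D, device="cuda", dtype=torch.float16)

allowed = torch.ones(B, 1, Q, K, dtype=torch.bool, device="cuda")    # placeholder for the real pattern
bias = torch.zeros(B, 1, Q, K, dtype=q.dtype, device="cuda").masked_fill(~allowed, float("-inf"))

out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)        # bias broadcasts over heads
```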

I wouldn't mind helping to get LADE working with FA2, but FA support is already on the roadmap (https://github.com/hao-ai-lab/LookaheadDecoding/issues/13#issue-2009960949) (well, we have FA1 but not FA2 support yet) and the blog mentioned that CUDA kernels are coming. I don't really want to work on something that's already in progress, as it feels like wasted effort, but I'm not sure what their current progress on their FA-compatible CUDA kernels is like.

  1. I agree. The authors said that the kernels are on the way, so fingers crossed that they come out soon.

  2. Yep. Technically, the matmul shape is just as square in both cases, though it may not look like it. With the KV cache, regular autoregressive decoding just takes the top row and trims off the lookahead + verification branch parts on the right.

Even without a KV cache for autoregressive decoding in a single pass, multiply-add operations are still done over masked parts of the tensors.

shermansiu commented 8 months ago

It should also be noted that whether LADE can use more FLOPs to achieve a speedup depends on how powerful the GPU is. The authors only tested on A100 GPUs.

Hi @yhyu13. You can check the table 1 in our blog. We require large extra flops to predict tokens. When the GPU is weak or the model is larger, we need to reduce this cost (and also, we will predict fewer tokens), or it will bring a slowdown.

yhyu13 noticed a slowdown using the default parameters when running a 13B model on a V100. After adjusting the parameters there was a speedup, but it was small.

Joao Gante from the Huggingface transformers team also noticed a 33% slowdown with the default parameters for a 7B LLM on a 3090 (https://github.com/huggingface/transformers/issues/27649#issuecomment-1824621466). It took some manual hyperparameter adjustments to get a 25% speedup.

learning-chip commented 8 months ago

The following answers from the lead author may also help. #28 (comment) #14 (comment)

Thanks for the pointer! The question here https://github.com/hao-ai-lab/LookaheadDecoding/issues/28#issuecomment-1846924463 is exactly my question. The author commented:

If you use red 5 as the previous token of red 6, I think it does not make much sense as the red 6 has no relationship with red5, and it may not generate a meaningful n-gram.

Indeed, here red-6 attends to {green-5, orange-4,3,2,1, blue-0}. Even if it shouldn't attend to red-5, it could also attend to other choices -- e.g. {green-5,4,3,2,1, blue-0}, or {green-5,4, orange-3,2,1, blue-0}, or other mixes of green/orange tokens at those positions.

Those are all valid cross-iteration tokens that span positions 5,4,3,2,1,0. Is any one of them fundamentally better?

learning-chip commented 8 months ago

The author comment https://github.com/hao-ai-lab/LookaheadDecoding/issues/14#issuecomment-1826198518 is a very good illustration. I do see the need of spanning-over-iterations in the "n-gram collection phase" -- otherwise, if staying within the same iteration (same color, all red / all green / all orange), then there is no creation of new "n-grams", and the algorithm falls back to the original Jacobi decoding paper

However, the "n-gram collection phase" is done out side of the attention layer computation, right? I can see the entire lookahead decoding algorithm as two-phases:

For phase (2), I totally understand why red-6 is grouped with {green-5, orange-4} but not grouped with {red-5, red-4} -- we need the cross-iteration "trajectory" here. But the attention mask pattern is used in phase (1), and even if restricting to in-iteration tokens (red-6 attends to red-6,5,3,2), the model.forward computation is still performing meaningful Jacobi iterations to improve the guess.

xinlong-yang commented 8 months ago

The author comment #14 (comment) is a very good illustration. I do see the need of spanning-over-iterations in the "n-gram collection phase" -- otherwise, if staying within the same iteration (same color, all red / all green / all orange), then there is no creation of new "n-grams", and the algorithm falls back to the original Jacobi decoding paper

However, the "n-gram collection phase" is done out side of the attention layer computation, right? I can see the entire lookahead decoding algorithm as two-phases:

  • (1) "Transformer computation phase", or "parallel decoding phase", or "Jacobi iteration phase", code is jforward_multilevel & LlamaModeljforward: build specialized attention mask (blog figure 5), concat multi-branch inputs, send to Llama model.forward, obtain Jacobi updates and verification results from different segments of concat-ed model output.
  • (2) "N-gram pool management phase" or "token map look-up phase", code is jacobi_greedy_search_multilevel(): based on phase (1)'s output segments, collect cross-iteration n-grams, and also look-up from n-gram cache to refine the speculative guess tokens.

For phase (2), I totally understand why red-6 is grouped with {green-5, orange-4} but not grouped with {red-5, red-4} -- we need the cross-iteration "trajectory" here. But the attention mask pattern is used in phase (1), and even if restricting to in-iteration tokens (red-6 attends to red-6,5,3,2), the model.forward computation is still performing meaningful Jacobi iterations to improve the guess.

hello, bro, I'm also confused about the mask design -- have you figured this problem out? I wonder if you can help make it clear hhhh, thanks

shermansiu commented 8 months ago

Those are all valid cross-iteration tokens that span positions 5,4,3,2,1,0. Is any one of them fundamentally better?

red-6 attending to {green-5, orange-4,3,2,1, blue-0} is the only one that corresponds to sequential Lookahead decoding, which has some intuition derived from Jacobi decoding.

As for the other options, you could try them, but then it would no longer be Lookahead decoding, but a derivative method. As for whether it's better or not, my guess is that the difference wouldn't be huge, but as always, someone would need to run experiments to assess the performance difference across a variety of prompts and models. Basically, we don't know until we try.

It's also important to remember how the orange, green, and red tokens are chosen when being passed to the model. With lookahead decoding, the orange tokens are randomly selected from the vocabulary and the green and red tokens are taken from previous passes through the model using lookahead decoding. This affects how we interpret the method.

shermansiu commented 8 months ago

hello, bro, I'm also confused about the mask design -- have you figured this problem out? I wonder if you can help make it clear hhhh, thanks

What exactly is unclear to you?

learning-chip commented 8 months ago

red-6 attending to {green-5, orange-4,3,2,1, blue-0} is the only one that corresponds to sequential Lookahead decoding, which has some intuition derived from Jacobi decoding.

What's the definition of "sequential lookahead decoding", and how exactly is it different from the original Jacobi decoding? I thought that everything within jforward_multilevel() is the original Jacobi decoding computation (just run concurrently over multiple branches), while the extra lookahead logic lives within jacobi_greedy_search_multilevel(), one level higher in the code. Maybe my understanding was wrong? I need a paper to clarify the exact lookahead decoding algorithm here😅

From the lookahead branch section of the blog:

In the figure, the blue token labeled 0 is the current input. The tokens in orange, green, and red were generated in previous Jacobi iterations at steps t-3, t-2, t-1, respectively

At the current step we conduct one Jacobi iteration to generate new tokens for all 5 positions, using the trajectory formed by the previous 3 steps.

The "Jacobi iteration", if we use the original definition in the Jacobi decoding paper, is computed within a single time step (i.e. only using information from step t-1, so all red tokens; or only using information from step t-2, so all green tokens...). So the word "Jacobi iteration" is the blog must have a modified meaning...

learning-chip commented 8 months ago

With lookahead decoding, the orange tokens are randomly selected from the vocabulary

That's the random initialization for Jacobi iteration, right?

the green and red tokens are taken from previous passes through the model using lookahead decoding.

Is "using lookahead decoding" here equivalent to "using Jacobi decoding" (i.e. the same algorithm as described in the Santilli 2023 paper)?

My reference is those sentences from the blog:

While lookahead decoding performs parallel decoding using Jacobi iterations for future tokens, it also concurrently verifies promising n-grams from the cache.

The tokens in orange, green, and red were generated in previous Jacobi iterations at steps t-3, t-2, t-1, respectively

At the current step we conduct one Jacobi iteration to generate new tokens for all 5 positions, using the trajectory formed by the previous 3 steps.

I thought "Jacobi iteration" should be defined as:

[image: jacobi_formula — $y_i^{(t)} = \arg\max_y\, p\!\left(y \mid y_{1:i-1}^{(t-1)},\, x\right)$, all positions $i$ updated in parallel]

However, if defined in this way, the dependent tokens are restricted to step t-1; there is no way to have "cross-iteration" / "cross-color" dependencies in the attention mask pattern. So the term "Jacobi iteration" used in the blog must have a somewhat changed meaning, and this causes confusion😅

xinlong-yang commented 8 months ago

With lookahead decoding, the orange tokens are randomly selected from the vocabulary

That's the random initialization for Jacobi iteration, right?

the green and red tokens are taken from previous passes through the model using lookahead decoding.

Is "using lookahead decoding" here equivalent to "using Jacobi decoding" (i.e. the same algorithm as described in the Santilli 2023 paper)?

My reference is those sentences from the blog:

While lookahead decoding performs parallel decoding using Jacobi iterations for future tokens, it also concurrently verifies promising n-grams from the cache.

The tokens in orange, green, and red were generated in previous Jacobi iterations at steps t-3, t-2, t-1, respectively

At the current step we conduct one Jacobi iteration to generate new tokens for all 5 positions, using the trajectory formed by the previous 3 steps.

I thought "Jacobi iteration" should be defined as:

[image: jacobi_formula — $y_i^{(t)} = \arg\max_y\, p\!\left(y \mid y_{1:i-1}^{(t-1)},\, x\right)$, all positions $i$ updated in parallel]

However, if defined in this way, the dependent tokens are restricted to step t-1; there is no way to have "cross-iteration" / "cross-color" dependencies in the attention mask pattern. So the term "Jacobi iteration" used in the blog must have a somewhat changed meaning, and this causes confusion😅

yes, I think the procedure in the blog's Figure 4 is clear, but I can't make the connection between this procedure and Figure 5's mask. I am also confused about why we need to attend to tokens from step t-3 and step t-2.

shermansiu commented 8 months ago

What's the definition of "sequential lookahead decoding", and how exactly is it different from the original Jacobi decoding?

"Lookahead decoding takes advantage of Jacobi decoding's ability by collecting and caching n-grams generated from Jacobi iteration trajectories."

That's the main difference. Jacobi iteration checks for the longest matching prefix from the previous iteration/the current input, whereas sequential lookahead decoding looks for N-gram matches generated from the previous Jacobi iteration trajectories.

shermansiu commented 8 months ago

That's the random initialization for Jacobi iteration, right?

Yes.

Is "using lookahead decoding" here equivalent to "using Jacobi decoding" (i.e. the same algorithm as described in the Santilli 2023 paper)?

As mentioned in the previous comment, the biggest difference is how verification is done. i.e. With N-grams instead of the longest matching prefix.

However, if defined in this way, the dependent token is restrictedly within step t-1, there is no way to have "cross-iteration" / "cross-color" dependencies in the attention mask pattern. So the word "Jacobi iteration" used in the blog must have somewhat a changed meaning, and this causes confusion😅

Not really? That just shows why red-5 must depend on green-4 and not orange-4 or red-4 and similarly, why red-4 must depend on green-3 and not orange-3 or red-3. Previous iterations would be subsumed into the term on the right.

shermansiu commented 8 months ago

yes, I think the procedure in the blog's Figure 4 is clear, but I can't make the connection between this procedure and Figure 5's mask. I am also confused about why we need to attend to tokens from step t-3 and step t-2.

The step t-3, step t-2, etc. tokens are the green tokens in the input. You'll start seeing green tokens in the input in Figure 4 in Steps 2 and later.

xinlong-yang commented 8 months ago

yes, I think the procedure in the blog's Figure 4 is clear, but I can't make the connection between this procedure and Figure 5's mask. I am also confused about why we need to attend to tokens from step t-3 and step t-2.

The step t-3, step t-2, etc. tokens are the green tokens in the input. You'll start seeing green tokens in the input in Figure 4 in Steps 2 and later.

Yes, I can see that. Could you please help answer the following questions? (1) As for the verification branch: at the current step, the two light-blue 1,2,3 sequences are candidate n-grams, and they are verified in parallel using a normal causal mask, right? I think this part is easy to understand. (2) As for the lookahead branch: I'm confused about the "parallel decoding step" -- should we also apply a normal causal mask, or do we just design the special mask described in the blog to construct n-grams? Many thanks for your patience; I'm new to the LLM field and look forward to your reply!

shermansiu commented 8 months ago

(1) As for the verification branch: at the current step, the two light-blue 1,2,3 sequences are candidate n-grams, and they are verified in parallel using a normal causal mask, right? I think this part is easy to understand.

Yeah, pretty much, even though it may not look like it at first glance (i.e. the mask isn't lower triangular).

Note that using a normal lower triangular mask for the verification branch will not give you parallel verification, but I think you understand that.
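To make that concrete, here is a toy sketch (my own, not the LADE code) of a verification mask for two 3-token candidate n-grams: each candidate attends to the prompt and causally to itself, but never to the other candidate.

```python
# Toy sketch: boolean mask (True = may attend) for verifying two 3-token candidate
# n-grams in one forward pass.
import torch

prompt_len, n, num_cand = 4, 3, 2
total = prompt_len + num_cand * n
allowed = torch.zeros(total, total, dtype=torch.bool)
allowed[:prompt_len, :prompt_len] = torch.tril(torch.ones(prompt_len, prompt_len)).bool()
for c in range(num_cand):
    s = prompt_len + c * n
    allowed[s:s + n, :prompt_len] = True                             # each candidate sees the full prompt
    allowed[s:s + n, s:s + n] = torch.tril(torch.ones(n, n)).bool()  # causal within the candidate only
print(allowed.int())
```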

(2) As for the lookahead branch: I'm confused about the "parallel decoding step" -- should we also apply a normal causal mask, or do we just design the special mask described in the blog to construct n-grams? Many thanks for your patience; I'm new to the LLM field and look forward to your reply!

For lookahead decoding, it's better to use the mask described in the blog. It's a special mask.

learning-chip commented 8 months ago

Not really? That just shows why red-5 must depend on green-4 and not orange-4 or red-4 and similarly, why red-4 must depend on green-3 and not orange-3 or red-3.

That Jacobi iteration formula indicates that red-5 must depend on {green-4,3,2,1} (i.e. the y(t-1) state computed from the y(t-2) state), but not on {green-4, orange-3,2,1} as specified by the lookahead mask (i.e. the y(t-1) state computed from both the y(t-2) and y(t-3) states) -- the latter cannot be the Jacobi iteration formula, but rather a different multi-step iterative scheme.
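In symbols (my own notation; $x$ stands for the prompt plus blue-0, red = the $y^{(t-1)}$ state, green = $y^{(t-2)}$, orange = $y^{(t-3)}$):

$$\text{textbook Jacobi:}\quad y_5^{(t-1)} = \arg\max_y\, p\!\left(y \,\middle|\, y_{1:4}^{(t-2)},\, x\right), \qquad \text{lookahead mask:}\quad y_5^{(t-1)} = \arg\max_y\, p\!\left(y \,\middle|\, y_4^{(t-2)},\, y_{1:3}^{(t-3)},\, x\right).$$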

I am not arguing that we should not use such a multi-step iterative scheme. I can run numerical experiments (testing different mask designs like https://github.com/hao-ai-lab/LookaheadDecoding/issues/44#issuecomment-1891651152) to see whether this improves convergence compared to the single-step formulation. I just think that any modified scheme should not be called "Jacobi iteration" anymore, to avoid confusing the reader... By analogy, for ODE solvers, multi-step methods have their own names.

Note: the above discussion has nothing to do with n-gram collection yet -- I am only talking about the LlamaModeljforward() phase that builds the mask and applies parallel decoding (in the llama.py file), not the jacobi_greedy_search_multilevel() phase (in decoding.py, which constructs the token_map of n-grams). The n-gram collection phase clearly spans steps (t-3, t-2, t-1), no question. But the parallel decoding phase, if it is strictly a Jacobi method (by the textbook definition), has no way to look across (t-3, t-2, t-1) in a single iteration. Of course we could instead call it "a modified multi-step Jacobi-like parallel decoding"; then there would be no confusion... The confusion comes from the blog repeatedly using the term "Jacobi iteration" without re-defining it https://github.com/hao-ai-lab/LookaheadDecoding/issues/44#issuecomment-1894710029

shermansiu commented 8 months ago

That Jacobi iteration formula indicates that red-5 must depend on {green-4,3,2,1} (i.e. the y(t-1) state computed from the y(t-2) state), but not on {green-4, orange-3,2,1} as specified by the lookahead mask (i.e. the y(t-1) state computed from both the y(t-2) and y(t-3) states) -- the latter cannot be the Jacobi iteration formula, but rather a different multi-step iterative scheme.

It looks at several sequences in parallel because lookahead decoding focuses more on the left-hand side of the parallel decoding. Please refer to Figures 3/4 from the blog and pay attention to the border between the green (accepted) and yellow (guess) tokens. In step n, where n $\gg$ 1, for example, the last few tokens in the lookahead decoding mask (which would be yellow/guess tokens) are not included at all. In fact, red-6, for example, would be the left-most yellow (guess) token for the "red" step. You're right that the output tokens depend on the previous guess tokens, but lookahead decoding doesn't really translate that into the mask: beyond the first step of the sequential lookahead decoding/Jacobi step, all guess tokens (except the first) are discarded and not even included in the mask.

Now, you might ask, why keep several orange tokens then? Because I just said that the green-orange progression comes from looking at the accepted tokens from previous steps? The answer is because we are doing several lookahead steps in parallel. i.e. We pretend that several orange guess tokens have been accepted already even though they actually haven't. Thus, in the equations you are referencing, those orange tokens would be indexed as $\le i-1$. Recall that $i$ is incremented by 1 in each step of Jacobi decoding, and we are running several Jacobi steps in parallel in lookahead decoding.

As you noted, there is a difference in the methodology. But I think it's between sequential lookahead decoding and parallel lookahead decoding (i.e. full lookahead decoding) rather than between Jacobi decoding and sequential lookahead decoding.

learning-chip commented 8 months ago

The answer is because we are doing several lookahead steps in parallel. i.e. We pretend that several orange guess tokens have been accepted already even though they actually haven't.

This is a good explanation! I get the idea that this gives an extra degree of parallelism along the "lookahead step" dimension (N), in addition to the Jacobi window dimension (W in the blog, or c in the Jacobi formula). Also ref https://github.com/hao-ai-lab/LookaheadDecoding/issues/14#issuecomment-1844205330

But I think it's between sequential lookahead decoding and parallel lookahead decoding (i.e. full lookahead decoding) rather than between Jacobi decoding and sequential lookahead decoding.

Now I do get the motivation for the multi-step iteration formula, but I would not call the method in the blog "Jacobi iteration" anymore😅 By analogy to ODE solvers, a two-step formula is different from one-step Euler, despite being more accurate.

shermansiu commented 8 months ago

Sounds good! 🤗