hao-ai-lab / LookaheadDecoding

[ICML 2024] Break the Sequential Dependency of LLM Inference Using Lookahead Decoding
https://arxiv.org/abs/2402.02057
Apache License 2.0

Why use Jacobi decoding? What are the advantages besides the fact that Jacobi decoding can reduce some of the decoding steps? #33

Closed: ZipECHO closed this issue 8 months ago

ZipECHO commented 9 months ago

What are the advantages besides the fact that Jacobi decoding can reduce some of the decoding steps? Can it be replaced by normal greedy decoding? How much of a negative impact would this have on decoding efficiency?

shermansiu commented 9 months ago

I'm not one of the authors, but I'll answer.

The reason we would apply Lookahead Decoding in the first place is that it can reduce the number of decoding steps, at the cost of extra FLOPS. Faster decoding is always nice.

Yes, of course, it can be replaced by normal greedy decoding. But plain greedy decoding might be slower than LADE with the optimal settings for your GPU (assuming you have enough spare FLOPS to get a speedup out of LADE).

As for the negative impact, it depends on your GPU, the LADE settings, and the number of FLOPS needed to run your LLM. In general, you'll get more of a benefit from using LADE if you have a powerful GPU (e.g. an A100) as opposed to a 3090 or a V100.

jivanph commented 8 months ago

> The reason we would apply Lookahead Decoding in the first place is that it can reduce the number of decoding steps, at the cost of extra FLOPS. Faster decoding is always nice.

Do you happen to know of a simple way to measure the extra FLOPS you mention, to better understand the trade-off that Lookahead offers, not only in terms of decoding steps but also in terms of added computation?

shermansiu commented 8 months ago

The way to do this would be to find the optimal parameters experimentally for multiple model sizes on several GPUs and formulate a statistical scaling law for the optimal lookahead decoding parameters at each level.

I don't have the $$$ needed for this (for cloud GPUs), unfortunately.

jivanph commented 8 months ago

I meant something simpler, like computing FLOPs with and without Lookahead for some fixed parameters. They don't have to be the optimal parameters, just any parameters that give faster decoding.

shermansiu commented 8 months ago

Without lookahead, would something like this work? https://blog.eleuther.ai/transformer-math

Otherwise, no simple way exists yet. Someone will first have to do the hard work of conducting the experiments and creating a statistical model before it becomes easy.

shermansiu commented 8 months ago

I think estimating the speedup from the model size, batch size, GPU VRAM, GPU TFLOPS, etc. would be more useful, though, as it's hard to get an intuitive understanding of the impact of FLOPS.

jivanph commented 8 months ago

What you mention would surely be more useful, but also more demanding, and even beyond the scope of a comparison for a fixed setup and fixed parameters.

shermansiu commented 8 months ago

Using the transformer FLOPs equation from the post above, the total number of floating point operations would be roughly $2PD$ for a forward pass, where $P$ is the total number of non-embedding parameters and $D$ is the number of tokens in the (inference) dataset, or more practically, the number of prompts × the max output length.
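As a rough illustration of the $2PD$ rule, here is a minimal Python sketch. The parameter count and workload are hypothetical numbers chosen only for the example, and the function deliberately ignores the context-dependent attention cost.

```python
def forward_flops(num_params: int, num_tokens: int) -> int:
    """Approximate inference FLOPs as 2 * P * D.

    num_params: non-embedding parameter count P.
    num_tokens: tokens processed D, e.g. number of prompts * max output length.
    The context-dependent attention cost is ignored here.
    """
    return 2 * num_params * num_tokens


# Hypothetical workload: P = 7e9 non-embedding parameters,
# 100 prompts with up to 512 output tokens each.
P = 7_000_000_000
D = 100 * 512
print(f"~{forward_flops(P, D):.3e} FLOPs")
```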

You might think that it should be quadratic because of the nature of self-attention, but because of the KV-cache, it's actually linear. See https://kipp.ly/transformer-inference-arithmetic/#flops-counting for more details. Even for the exact formula that includes a $d_{model}^2$ term, $d_{model}$ is fixed for a specific model architecture, so...


Edit: No, it should definitely be quadratic. The derivation in https://kipp.ly/transformer-inference-arithmetic/#flops-counting omits this term, which is quite significant in practice. Basically, the formulas derive the number of FLOPs per token, and dismiss the $q\cdot k$ calculation as some constant factor with respect to $d_{model}$.

If we focus on the $\mathrm{softmax}(qK^T/\sqrt{d_{head}})V$ term in a single layer, each generated token costs $O(\ell\, d_{model})$ operations, where $\ell$ is the current sequence length (there is a single $q$ due to KV-caching, and the old $k$ and $v$ values are retrieved from the cache). Summed over a whole sequence of length $\ell$, that comes to $O(\ell^2 d_{model})$ operations.

Assuming that we are looking at all heads at once, $qK^T$ takes $2\ell d_{model}$ operations. The $\sqrt{d_{head}}$ scaling factor then takes $\ell$ time. exp in torch leverages C++'s implementation of exp, which is constant time per scalar due to lookup tables and fancy bit math. Assuming we are using the numerically stable (max-trick) softmax, $\max$ takes $\ell$ time, subtraction takes $\ell$ time, $\exp$ takes $\ell$ time, and summing the denominator takes $\ell$ time. Thus softmax takes roughly $4\ell + C$ time. The matrix multiplication with $V$ takes $2\ell d_{model}$ time. Per layer, this takes $4\ell d_{model}+5\ell+C$ time.
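The per-token tally can be written out as a small Python function. It is a sketch that copies the comment's own counting choices as-is (all heads counted together, softmax as four length-$\ell$ passes plus a constant $C$); the function name is made up for illustration.

```python
def attention_ops_per_token(ell: int, d_model: int, c: int = 0) -> int:
    """Ops for softmax(q K^T / sqrt(d_head)) V in one layer with one cached query.

    ell: current sequence length (number of cached keys/values).
    c:   the constant C from the softmax tally.
    """
    qk = 2 * ell * d_model      # q K^T across all heads
    scale = ell                 # divide scores by sqrt(d_head)
    softmax = 4 * ell + c       # max, subtract, exp, sum (+ constant)
    av = 2 * ell * d_model      # attention weights times V
    return qk + scale + softmax + av


# Same as 4 * ell * d_model + 5 * ell + C
ell, d_model, c = 1024, 8192, 3
assert attention_ops_per_token(ell, d_model, c) == 4 * ell * d_model + 5 * ell + c
```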

Even though these terms can be ignored on a per-token basis, for a completion that goes up to a max length of $L$, summing over all positions gives a triangular number, so we actually get $\frac{L^2+L}{2}(4d_{model}+5)n_{layer}+CLn_{layer}$, which is non-negligible. This adjustment factor is omitted in most calculations. We'll approximate it as $2L^2d_{model}n_{layer}$.
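To check the triangular-number step, this sketch sums the per-layer cost over every position $\ell = 1, \dots, L$ and compares the result with the closed form and with the $2L^2d_{model}n_{layer}$ approximation. The sizes in the usage lines are arbitrary example values.

```python
def summed_attention_ops(L: int, d_model: int, n_layer: int, c: int = 0) -> int:
    """Brute-force sum of (4*ell*d_model + 5*ell + C) * n_layer for ell = 1..L."""
    return sum(4 * ell * d_model + 5 * ell + c for ell in range(1, L + 1)) * n_layer


def closed_form_attention_ops(L: int, d_model: int, n_layer: int, c: int = 0) -> int:
    """(L^2 + L)/2 * (4*d_model + 5) * n_layer + C * L * n_layer (L^2 + L is always even)."""
    return (L * L + L) // 2 * (4 * d_model + 5) * n_layer + c * L * n_layer


def approx_attention_ops(L: int, d_model: int, n_layer: int) -> int:
    """Leading-order approximation 2 * L^2 * d_model * n_layer."""
    return 2 * L * L * d_model * n_layer


L, d_model, n_layer, c = 2048, 4096, 32, 7
exact = closed_form_attention_ops(L, d_model, n_layer, c)
assert exact == summed_attention_ops(L, d_model, n_layer, c)
print(approx_attention_ops(L, d_model, n_layer) / exact)  # slightly below 1
```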

If we use the same Anthropic model as in the kipp.ly post ($n_{layer}=64$, $d_{model}=8192$) and assume a sequence length of 8192, we get an adjustment factor of $2(8192)^2(8192)(64)\approx 7.04\times10^{13}$, or about 70.4 trillion FLOPs. This is far larger than the roughly 103B FLOPs per token ($2P$) we would have expected otherwise. Even if we use a more reasonable sequence length of 512, we get $2(512)^2(8192)(64)=274{,}877{,}906{,}944$, or about 275B FLOPs.
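The arithmetic can be reproduced with the sketch below, which also places the correction next to the linear $2PL$ term for the same completion, since 103B is a per-token figure. It assumes the kipp.ly shape ($n_{layer}=64$, $d_{model}=8192$) and the common approximation $P \approx 12\,n_{layer}d_{model}^2$ for the non-embedding parameter count; under that approximation the ratio simplifies to $L/(12\,d_{model})$, the same quantity behind the OpenAI condition $d_{model} \gg n_{ctx}/12$ quoted later in the thread.

```python
n_layer, d_model = 64, 8192          # kipp.ly's Anthropic 52B example shape
P = 12 * n_layer * d_model ** 2      # approx. non-embedding params; 2P ~= 1.03e11 FLOPs/token


def attention_correction(L: int) -> int:
    """Leading-order attention cost 2 * L^2 * d_model * n_layer for a length-L completion."""
    return 2 * L ** 2 * d_model * n_layer


for L in (8192, 512):
    corr = attention_correction(L)
    linear = 2 * P * L               # 2PD with D = L tokens
    print(f"L={L}: correction={corr:.3e}  2PL={linear:.3e}  ratio={corr / linear:.2%}")
```

With these numbers the ratio is $8192/(12\cdot 8192) = 1/12 \approx 8.3\%$ at $L=8192$ and about $0.5\%$ at $L=512$.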

In practice, not all sequences in a dataset have exactly the same length, but this shows that you can't neglect the quadratic nature of attention, even with a KV-cache.

Note that the coefficient of 2 has not been validated experimentally, so take it with a grain of salt.

shermansiu commented 8 months ago

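For lookahead decoding, it seems that the lookahead branch uses $1+(N-1)W$ tokens and the verification branch uses $(N-1)G$ tokens, where $N$ is the n-gram size, $W$ is the lookahead window size, and $G$ is the maximum number of candidate n-grams verified per step.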

That's $1+(N-1)(W+G)$ tokens in total. Let $\alpha=(N-1)(W+G)$. Without lookahead decoding, we evaluate 1 query token/step, and with lookahead decoding, we evaluate $\alpha+1$ query tokens/step.

The default parameters are $N=7$, $W=20$, and $G=20$. Thus $\alpha=240$.
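A quick Python check of the token counts, using the thread's formulas and the default $N=7$, $W=20$, $G=20$ quoted above (the helper names are illustrative, not LADE functions):

```python
def lookahead_tokens_per_step(N: int, W: int, G: int) -> int:
    """Query tokens evaluated per step: lookahead branch + verification branch."""
    lookahead_branch = 1 + (N - 1) * W
    verification_branch = (N - 1) * G
    return lookahead_branch + verification_branch


def alpha(N: int, W: int, G: int) -> int:
    """Extra query tokens per step compared with plain decoding."""
    return (N - 1) * (W + G)


N, W, G = 7, 20, 20
print(alpha(N, W, G), lookahead_tokens_per_step(N, W, G))  # 240 241
```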

Edit: Using the adjusted equation above, if the prompt length is $\Lambda$ and the total sequence length at the end of the completion is $L$ (so $L-\Lambda$ tokens are generated), then it would take $2P(L-\Lambda) + 2(L^2-\Lambda^2)n_{layer}d_{model}$ FLOPs without lookahead decoding.

With lookahead decoding, this would then be $2P(L-\Lambda) + 2[(L+\alpha)^2-(\Lambda+\alpha)^2]n_{layer}d_{model}$.

Because $a^2-b^2=(a+b)(a-b)$, when $L-\Lambda=1$ (i.e. $\Lambda=L-1$), the additive correction factor becomes linear: $4(L+\alpha-\frac12)n_{layer}d_{model}$.

Edit 2: I overlooked the fact that the lookahead decoding part can't be cached with the KV-cache. This means that we evaluate $\alpha+1$ queries at each time step instead of just 1. So the FLOP count is actually $2P(\alpha+1)(L-\Lambda) + 2(\alpha+1)[(L+\alpha)^2-(\Lambda+\alpha)^2]n_{layer}d_{model}$.
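The two cost formulas (without lookahead, and the Edit 2 version with lookahead) can be sketched in Python as follows. Like the formula, the sketch charges one decoding step per generated token, so it does not credit lookahead decoding for accepting several tokens in one step. The model shape and lengths in the usage lines are hypothetical.

```python
def flops_without_lookahead(P: int, prompt_len: int, total_len: int,
                            n_layer: int, d_model: int) -> int:
    """2P(L - Lambda) + 2(L^2 - Lambda^2) n_layer d_model."""
    generated = total_len - prompt_len
    return 2 * P * generated + 2 * (total_len**2 - prompt_len**2) * n_layer * d_model


def flops_with_lookahead(P: int, prompt_len: int, total_len: int,
                         n_layer: int, d_model: int, alpha: int) -> int:
    """(alpha+1) * [2P(L - Lambda) + 2((L+alpha)^2 - (Lambda+alpha)^2) n_layer d_model]."""
    generated = total_len - prompt_len
    attn = (total_len + alpha) ** 2 - (prompt_len + alpha) ** 2
    return (alpha + 1) * (2 * P * generated + 2 * attn * n_layer * d_model)


# Hypothetical shape: 32 layers, d_model = 4096, P ~= 12 * n_layer * d_model^2.
n_layer, d_model = 32, 4096
P = 12 * n_layer * d_model**2
base = flops_without_lookahead(P, 128, 640, n_layer, d_model)
lade = flops_with_lookahead(P, 128, 640, n_layer, d_model, alpha=240)
print(f"without: {base:.3e}  with: {lade:.3e}  ratio: {lade / base:.1f}x")
```

Setting alpha=0 recovers the no-lookahead formula, which is a quick sanity check on the two functions.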

TL;DR: Under the default settings, lookahead decoding increases the FLOPs as if the prompt had 240 extra tokens and the batch size were multiplied by 241. @Viol2000, does this look right?

shermansiu commented 8 months ago

Note that the above equations assume that inference is done at 16-bit precision. Precision only matters for the memory equation for the KV-cache, which is not included here.

jivanph commented 8 months ago

One of the things that caught my eye when I first looked at LookAhead is that "[LookAhead decoding] Linearly decreases #decoding steps relative to log(FLOPs) used per decoding step".

@Viol2000, did you run experiments where you counted FLOPs? I ask because I was wondering where the log claim came from.

Viol2000 commented 8 months ago

We do not count FLOPs carefully; we just use the number of input tokens to approximate the per-step FLOPs. See this figure for the relation between log(FLOPs) and the compression ratio: https://lmsys.org/images/blog/laattention/match-scaling.png. I have a formulation in the paper that explains where the log comes from. Stay tuned!

Viol2000 commented 8 months ago

Hi @shermansiu, thanks for carefully counting the per-step FLOPs. A transformer's FLOPs come mainly from the attention and MLP layers. The MLP layers' FLOPs are proportional to the number of input tokens, and the attention layers' FLOPs are also roughly proportional to the number of input tokens. So I just use the number of input tokens to approximate the per-step FLOPs (i.e., if you input 241 tokens, you have roughly 241× the FLOPs).
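Under this input-token proxy, the total-compute trade-off reduces to one line: per-step cost scales with the number of input tokens, while the number of steps shrinks by the compression ratio (taken here to mean greedy decoding steps divided by lookahead decoding steps). A minimal sketch that inherits the proxy's simplifications; the compression ratio in the example is a made-up value, not a measured one.

```python
def relative_total_flops(tokens_per_step: int, compression_ratio: float) -> float:
    """Lookahead total FLOPs / greedy total FLOPs under the input-token proxy.

    Greedy decoding: 1 input token per step, S steps.
    Lookahead decoding: tokens_per_step input tokens per step, S / compression_ratio steps.
    """
    if compression_ratio <= 0:
        raise ValueError("compression_ratio must be positive")
    return tokens_per_step / compression_ratio


# Default settings from above (241 tokens/step) with a hypothetical 2x step reduction.
print(relative_total_flops(241, 2.0))  # 120.5
```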

shermansiu commented 8 months ago

Actually, with a KV-cache, increasing the number of generated tokens alone by $\delta$ should not multiply the number of FLOPs by a factor of $\alpha$. Normally, according to the OpenAI scaling laws, that would just increase the inference cost additively by $2P\delta$ (plus the correction factor above), i.e. new cost $\approx$ old cost $+\ 2P\delta + 2(2L\delta+\delta^2)n_{layer}d_{model}$.

Adding $\delta$ tokens to the prompt alone means: new cost $-$ old cost $= 2P(L+\delta-\Lambda-\delta) + 2[(L+\delta)^2-(\Lambda+\delta)^2]n_{layer}d_{model} - \bigl[2P(L-\Lambda) + 2(L^2-\Lambda^2)n_{layer}d_{model}\bigr] = 4\delta(L-\Lambda)n_{layer}d_{model}$.
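Both differences can be checked numerically with plain integers. The sketch below evaluates the no-lookahead cost formula at small, arbitrary values and asserts the two identities; cost is a hypothetical helper name.

```python
def cost(P: int, prompt_len: int, total_len: int, n_layer: int, d_model: int) -> int:
    """2P(L - Lambda) + 2(L^2 - Lambda^2) n_layer d_model."""
    return (2 * P * (total_len - prompt_len)
            + 2 * (total_len**2 - prompt_len**2) * n_layer * d_model)


P, n_layer, d_model = 1000, 4, 64     # small arbitrary values keep the integers readable
Lam, L, delta = 50, 300, 17

# Generating delta more tokens (prompt unchanged)
more_generated = cost(P, Lam, L + delta, n_layer, d_model) - cost(P, Lam, L, n_layer, d_model)
assert more_generated == 2 * P * delta + 2 * (2 * L * delta + delta**2) * n_layer * d_model

# Prompt longer by delta (same number of generated tokens)
longer_prompt = cost(P, Lam + delta, L + delta, n_layer, d_model) - cost(P, Lam, L, n_layer, d_model)
assert longer_prompt == 4 * delta * (L - Lam) * n_layer * d_model
```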

The problem is that the OpenAI compute estimate discards the context-dependent terms because, in their words, "Since we primarily study models where $d_{model} \gg n_{ctx}/12$, we do not include context-dependent terms in our training compute estimate." Interestingly enough, they also do not include the cost of the self-attention computation in Table 1 of their scaling laws paper. (The different notation comes from the OpenAI scaling laws paper.)

The reason we also have a multiplicative increase (on top of the additive increase) is that the lookahead decoding part is recomputed at each Jacobi decoding step and can't be KV-cached.


As an aside, I think the attention-mask FLOPs term in the OpenAI scaling laws paper is wrongly attributed. The $2n_{layer}n_{ctx}d_{attn}$ cost is not the cost of masking, but of multiplying $\mathrm{softmax}(\mathrm{mask}(qK^T))$, a $1\times n_{ctx}$ vector, by the $n_{ctx}\times d_{attn}$ matrix $V$, across all layers. Masking takes $n_{layer}n_{ctx}$ floating point operations (via an add).
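A naive count makes the attribution concrete: multiplying a $1\times n_{ctx}$ weight vector by an $n_{ctx}\times d_{attn}$ matrix takes one multiply and one add per term, i.e. $2n_{ctx}d_{attn}$ operations per layer, while an additive mask touches each of the $n_{ctx}$ scores once. This is an illustrative pure-Python sketch with made-up sizes, not how any framework actually implements attention.

```python
from typing import List, Tuple


def weights_times_v(weights: List[float], V: List[List[float]]) -> Tuple[List[float], int]:
    """Naive (1 x n_ctx) @ (n_ctx x d_attn) product, counting multiplies and adds."""
    n_ctx, d_attn = len(V), len(V[0])
    out, ops = [0.0] * d_attn, 0
    for j in range(d_attn):
        acc = 0.0
        for i in range(n_ctx):
            acc += weights[i] * V[i][j]
            ops += 2                      # one multiply, one add
        out[j] = acc
    return out, ops


def apply_additive_mask(scores: List[float], mask: List[float]) -> Tuple[List[float], int]:
    """Additive mask (0 or -inf per position): one add per score."""
    return [s + m for s, m in zip(scores, mask)], len(scores)


n_ctx, d_attn, n_layer = 6, 4, 3
V = [[1.0] * d_attn for _ in range(n_ctx)]
_, matmul_ops = weights_times_v([1.0 / n_ctx] * n_ctx, V)
_, mask_ops = apply_additive_mask([0.0] * n_ctx, [0.0] * n_ctx)
print(n_layer * matmul_ops, n_layer * mask_ops)   # 2*n_layer*n_ctx*d_attn, n_layer*n_ctx
```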

shermansiu commented 8 months ago

Here's an oversimplified but more intuitive way to think about it, using the dimensions of the $qK^T$ / $QK^T$ tensor. We are looking at per-token inference.

| Situation | Shape |
| --- | --- |
| No KV-cache | $QK^T$: $(\ell+1)\times (\ell+1)$ |
| KV-cache | $qK^T$: $1\times (\ell+1)$ |
| Prompt longer by $\delta$, KV-cache | $qK^T$: $1\times (\ell+\delta+1)$ |
| Speculative decoding, KV-cache | $QK^T$: $(\alpha+1)\times (\ell+\alpha+1)$ |

For the full generation, the +1 terms in the width (i.e. the second dimension) get added up to $L-\Lambda$.
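The table can be turned into a small Python sketch that reports each score-matrix shape and sums its entry count over a full generation, as a crude proxy for attention work. Here $\ell$ is taken to be the number of tokens already in the sequence before the new one, so a generation from prompt length $\Lambda$ to length $L$ runs over $\ell = \Lambda, \dots, L-1$; that indexing and the function names are assumptions made for illustration.

```python
from typing import Tuple


def score_shape(situation: str, ell: int, delta: int = 0, alpha: int = 0) -> Tuple[int, int]:
    """Shape of QK^T / qK^T for one decoding step, following the table above."""
    if situation == "no_kv_cache":
        return (ell + 1, ell + 1)
    if situation == "kv_cache":
        return (1, ell + 1)
    if situation == "longer_prompt":
        return (1, ell + delta + 1)
    if situation == "speculative":
        return (alpha + 1, ell + alpha + 1)
    raise ValueError(f"unknown situation: {situation}")


def total_score_entries(situation: str, prompt_len: int, total_len: int, **kwargs) -> int:
    """Sum of rows * cols over every decoding step ell = prompt_len .. total_len - 1."""
    total = 0
    for ell in range(prompt_len, total_len):
        rows, cols = score_shape(situation, ell, **kwargs)
        total += rows * cols
    return total


for name, kw in [("no_kv_cache", {}), ("kv_cache", {}),
                 ("longer_prompt", {"delta": 16}), ("speculative", {"alpha": 240})]:
    print(name, total_score_entries(name, 128, 640, **kw))
```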

jivanph commented 8 months ago

Thank you so much, @shermansiu. This discussion enlightened me. I'm still trying to understand the trade-off, and I now have a clearer picture of the FLOPs side. Do you know if there is an easy way in the code to extract the total number of steps taken during Lookahead runs? I see it in the printed summary, but I don't know if there's a way to also return this statistic as an output.

shermansiu commented 8 months ago

You mean something like saving it to a file or using it elsewhere in your code?

The relevant line for the variable steps is here: https://github.com/hao-ai-lab/LookaheadDecoding/blob/b756db313419298d292a927c6dda950020ec1073/lade/decoding.py#L495

Use it however you wish.

jivanph commented 8 months ago

Thank you, I'm trying to store the steps variable and not just print it.

shermansiu commented 8 months ago

You'd need to modify the code for that.

The simplest (but hacky) way to obtain the value without having to modify a bunch of return statements is to make the code call one of your own functions, say append_steps, at line 496 of the aforementioned file.
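A minimal sketch of such a hook, assuming you create your own module (the module name step_log.py and the functions below are hypothetical, not part of LADE). At the line referenced above you would add something like from step_log import append_steps followed by append_steps(steps), assuming steps is in scope there as the linked line suggests.

```python
# step_log.py: hypothetical helper module, not part of LADE.
from typing import List

STEP_LOG: List[int] = []


def append_steps(steps: int) -> None:
    """Record the decoding-step count reported for one generation call."""
    STEP_LOG.append(int(steps))


def reset_steps() -> None:
    """Clear recorded counts, e.g. before a new benchmark run."""
    STEP_LOG.clear()


def mean_steps() -> float:
    """Average number of decoding steps per recorded call (0.0 if none)."""
    return sum(STEP_LOG) / len(STEP_LOG) if STEP_LOG else 0.0


if __name__ == "__main__":
    # Simulated calls standing in for what the patched decoding loop would report.
    for s in (120, 96, 143):
        append_steps(s)
    print(STEP_LOG, mean_steps())
```

Because STEP_LOG lives at module level and Python caches imported modules, your script and the patched decoding code see the same list, which is what lets this approach work without changing any return values.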