huggingface / text-generation-inference

Large Language Model Text Generation Inference
http://hf.co/docs/text-generation-inference
Apache License 2.0

Return most probable tokens + logprobs #604

Closed Vinno97 closed 1 year ago

Vinno97 commented 1 year ago

Feature request

It would be great if the API could return a list of most probable tokens (along with their logprobs) for each step. This could be useful for many downstream tasks that require sampling of these.

Motivation

Many LLM-based applications benefit not just from having the most probable token, but also a list of output probabilities.

Your contribution

I'd be willing to contribute to this feature, but I'd love some guidance on that. Some initial reading of the code makes me think this won't require big changes in the code, though it'll touch many different parts of it. For that reason, I'd want to be sure of the preferred approach before wasting both my and your time.

I guess we could add this in (all variations of) Model#generate_token and add the top n values from next_token_logits to this iterator:

https://github.com/huggingface/text-generation-inference/blob/b7327205a6f2f2c6349e75b8ea484e1e2823075a/server/text_generation_server/models/flash_causal_lm.py#L891-L902

And then also add it to the Generation?

https://github.com/huggingface/text-generation-inference/blob/b7327205a6f2f2c6349e75b8ea484e1e2823075a/server/text_generation_server/models/flash_causal_lm.py#L974-L982

Provided we make sure the input and output are passed on correctly everywhere else they're touched, would this be the way to go?
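
For concreteness, here's roughly the kind of helper I have in mind on the Python side (just a sketch; `extract_top_tokens` is a made-up name, not existing TGI code):

```python
import torch

def extract_top_tokens(next_token_logits: torch.Tensor, top_n: int):
    """Sketch: the top_n most probable token ids and their logprobs
    for one sequence. Illustrative helper, not existing TGI code."""
    logprobs = torch.log_softmax(next_token_logits, dim=-1)
    top_logprobs, top_ids = torch.topk(logprobs, k=top_n, dim=-1)
    # These would be attached to the Generation alongside the chosen token.
    return top_ids.tolist(), top_logprobs.tolist()
```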

I'm not a Rust programmer, but the code seems straightforward enough (famous last words?). Just add the parameter everywhere the other params like top_k are also mentioned.


Regarding what to call this feature: OpenAI has an integer parameter called "logprobs" that does just this. However, that name won't work since TGI already has another parameter called logprobs. Perhaps top_n_logprobs, top_k_logprobs, top_n_tokens, or top_k_tokens?

njhill commented 1 year ago

@Vinno97 it's on my todo list to contribute a PR for this. We implemented it in our own fork of TGI (which is a bit behind on the latest changes in this repo, but also does some things differently and includes additional features like this).

You can see the relevant code here and for input tokens here. Note that we have more granular flags for choosing specifically which response details you'd like, and we don't compute the ones that aren't requested.

As well as requesting top_n candidates with logprobs, you can independently request the rank of every chosen/input token (i.e. its position in the tokens ordered by score at that step).
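
Conceptually, the rank is just this (a rough sketch of the idea, not the actual code from our fork):

```python
import torch

def token_rank(step_logits: torch.Tensor, chosen_id: int) -> int:
    """Rough sketch: 1-based rank of the chosen token at one step,
    i.e. 1 + the number of tokens with a strictly higher score."""
    return int((step_logits > step_logits[chosen_id]).sum().item()) + 1
```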

OlivierDehaene commented 1 year ago

@njhill, that would be great! What's the latency hit? That's my main concern especially in prefill as you end up doing a massive sort on the prompt_tokens x vocab size matrix.

Note that we have more granular flags for choosing specifically which response details you'd like, and we don't compute the ones that aren't requested.

We have something similar now, if I understood you correctly.

Vinno97 commented 1 year ago

I have added a draft PR with this feature. I merged @njhill's code and updated it to work with the newer codebase. See my PR for its current limitations.

njhill commented 1 year ago

What's the latency hit?

@OlivierDehaene I haven't really measured the latency impact yet, but I haven't noticed much. It only applies to requests that ask for it specifically and most of them don't. A max top_n value can be configured at deploy time, so this can be disabled in cases where the deployer doesn't want to allow it.

you end up doing a massive sort on the prompt_tokens x vocab size matrix.

Why do you need to sort the entire vocab? It's just a top_k operation right?

We have something similar now, if I understood you correctly.

I think so but I'm not sure that it's as granular as what we had done, i.e. separate flags for requesting input tokens, logprobs, token ranks, top_n candidates, etc.

I have added a draft PR with this feature.

Thanks @Vinno97, I was on vacation last week and am still catching up, but I will try to take a look at your PR soon.

We could also look into vectorizing this in the batch dimension, as was done for the logits processors/warpers.

Vinno97 commented 1 year ago

I thought I had responded before, but apparently I never finished my reply. Anyway:

We could also look into vectorizing this in the batch dimension, as was done for the logits processors/warpers.

I thought about this as well when porting the code, but I first wanted to get the initial code up and running before getting to anything like that.

What's the latency hit?

Very significant. Though admittedly, I'm currently calling .item() and decoding every top token separately in a loop. Here's some "quick" benchmarking (I should've just profiled in the end). Done using google/flan-t5-small on 2 Nvidia L4 GPUs.

**Benchmarking results**

| Parameter | Value |
|---|---|
| Model | google/flan-t5-small |
| Sequence Length | 10 |
| Decode Length | 8 |

### Top-n-tokens=0

> All `top_n_tokens` code is skipped

| Step | Batch Size | Average | Lowest | Highest | p50 | p90 | p99 |
|---|---|---|---|---|---|---|---|
| Decode (token) | 1 | 17.18 ms | 16.67 ms | 19.11 ms | 17.15 ms | 17.67 ms | 17.47 ms |
| | 2 | 18.37 ms | 17.83 ms | 18.89 ms | 18.33 ms | 18.83 ms | 18.76 ms |
| | 4 | 19.48 ms | 18.68 ms | 24.45 ms | 19.33 ms | 20.08 ms | 19.13 ms |
| | 8 | 23.09 ms | 22.27 ms | 27.70 ms | 22.73 ms | 25.83 ms | 22.60 ms |
| | 16 | 25.71 ms | 24.98 ms | 29.98 ms | 25.57 ms | 26.36 ms | 25.81 ms |
| | 32 | 33.14 ms | 32.16 ms | 38.42 ms | 33.01 ms | 33.72 ms | 38.42 ms |

### top_n_tokens=50

> Current PR (commit 789d809e515e506b64b6d7d48e1dfa21b57a2f8a)

| Step | Batch Size | Average | Lowest | Highest | p50 | p90 | p99 |
|---|---|---|---|---|---|---|---|
| Decode (token) | 1 | 34.71 ms | 34.32 ms | 35.86 ms | 34.61 ms | 35.27 ms | 34.45 ms |
| | 2 | 52.90 ms | 51.93 ms | 64.27 ms | 52.60 ms | 53.13 ms | 52.66 ms |
| | 4 | 86.93 ms | 85.75 ms | 89.39 ms | 86.91 ms | 87.87 ms | 86.42 ms |
| | 8 | 158.25 ms | 156.63 ms | 165.01 ms | 157.78 ms | 160.21 ms | 157.10 ms |
| | 16 | 295.65 ms | 291.76 ms | 307.58 ms | 294.76 ms | 303.00 ms | 293.72 ms |
| | 32 | 573.43 ms | 566.89 ms | 590.43 ms | 572.00 ms | 585.07 ms | 590.43 ms |

### top_n_tokens=50, no detokenization, no .item()

> Remove the .item() and detokenization (replaced by a constant)

| Step | Batch Size | Average | Lowest | Highest | p50 | p90 | p99 |
|---|---|---|---|---|---|---|---|
| Decode (token) | 1 | 29.31 ms | 28.77 ms | 29.96 ms | 29.34 ms | 29.96 ms | 29.20 ms |
| | 2 | 41.85 ms | 41.45 ms | 43.70 ms | 41.75 ms | 43.70 ms | 41.90 ms |
| | 4 | 66.42 ms | 65.01 ms | 70.81 ms | 65.97 ms | 68.70 ms | 70.81 ms |
| | 8 | 115.40 ms | 114.39 ms | 118.62 ms | 115.03 ms | 118.62 ms | 116.77 ms |
| | 16 | 210.20 ms | 208.52 ms | 220.53 ms | 209.51 ms | 220.53 ms | 213.09 ms |
| | 32 | 402.04 ms | 398.08 ms | 421.48 ms | 400.45 ms | 411.07 ms | 421.48 ms |

### top_n_tokens=50, no loop (`top_tokens = [TopToken(...)] * len(top_n_indices)`)

> Remove the entire `for token in top_n_indices` loop and just build one array with the same object

| Step | Batch Size | Average | Lowest | Highest | p50 | p90 | p99 |
|---|---|---|---|---|---|---|---|
| Decode (token) | 1 | 20.67 ms | 20.17 ms | 21.24 ms | 20.65 ms | 21.24 ms | 20.78 ms |
| | 2 | 24.74 ms | 23.77 ms | 25.34 ms | 24.78 ms | 25.34 ms | 23.77 ms |
| | 4 | 30.25 ms | 29.63 ms | 31.14 ms | 30.21 ms | 31.14 ms | 29.93 ms |
| | 8 | 45.24 ms | 44.30 ms | 48.16 ms | 44.68 ms | 48.16 ms | 48.10 ms |
| | 16 | 69.38 ms | 68.73 ms | 70.45 ms | 69.25 ms | 70.35 ms | 70.45 ms |
| | 32 | 120.27 ms | 119.59 ms | 122.06 ms | 120.12 ms | 121.92 ms | 122.06 ms |

⚠️ The only thing this version does less of is looping over the tensor dimensions and creating per-token objects, yet it makes a significant difference.

### top-n-tokens=50, no loop, no sort

> Same as above, but without the final sort

| Step | Batch Size | Average | Lowest | Highest | p50 | p90 | p99 |
|---|---|---|---|---|---|---|---|
| Decode (token) | 1 | 18.01 ms | 17.62 ms | 18.88 ms | 17.99 ms | 18.53 ms | 18.88 ms |
| | 2 | 19.45 ms | 19.27 ms | 19.80 ms | 19.39 ms | 19.80 ms | 19.49 ms |
| | 4 | 21.97 ms | 21.32 ms | 22.77 ms | 21.98 ms | 22.77 ms | 22.23 ms |
| | 8 | 28.33 ms | 27.12 ms | 31.97 ms | 28.24 ms | 31.97 ms | 27.80 ms |
| | 16 | 35.89 ms | 35.09 ms | 36.49 ms | 35.91 ms | 36.49 ms | 36.26 ms |
| | 32 | 53.12 ms | 52.07 ms | 55.53 ms | 52.99 ms | 54.33 ms | 55.53 ms |

⚠️ The sort indeed makes a huge difference.

### top-n-tokens=50, no sort

> Plain 789d809e515e506b64b6d7d48e1dfa21b57a2f8a without the sort

| Step | Batch Size | Average | Lowest | Highest | p50 | p90 | p99 |
|---|---|---|---|---|---|---|---|
| Decode (token) | 1 | 25.16 ms | 24.65 ms | 25.77 ms | 25.17 ms | 25.77 ms | 24.95 ms |
| | 2 | 34.44 ms | 33.51 ms | 41.22 ms | 34.24 ms | 41.22 ms | 34.14 ms |
| | 4 | 50.08 ms | 49.17 ms | 53.41 ms | 49.82 ms | 53.41 ms | 49.83 ms |
| | 8 | 83.52 ms | 82.58 ms | 86.72 ms | 83.27 ms | 86.72 ms | 82.90 ms |
| | 16 | 146.44 ms | 144.31 ms | 155.13 ms | 146.00 ms | 155.13 ms | 147.14 ms |
| | 32 | 275.56 ms | 271.85 ms | 286.07 ms | 274.26 ms | 285.19 ms | 286.07 ms |

I wasn't able to work on it last week due to some other priorities that came up. However, these tasks are now done and I should have time to take it up again this week.

I initially wanted to just optimize the current function and implement it for all models. However, these benchmarks make me think it may be smart to immediately look into vectorizing over the batch.

njhill commented 1 year ago

Thanks @Vinno97, another difference in our fork is that we do all of the token decoding on the Rust side. Also, we have the max allowed value of n set to 5... 50 seems pretty high, but I guess it depends on the use cases.

Vinno97 commented 1 year ago

I purposefully set n very high, but maybe that was indeed a bit too high. For top-n-tokens=5 there is also a noticeable hit, but less so.

| Step | Batch Size | Average | Lowest | Highest | p50 | p90 | p99 |
|---|---|---|---|---|---|---|---|
| Decode (token) | 1 | 18.83 ms | 18.50 ms | 19.49 ms | 18.80 ms | 19.49 ms | 18.77 ms |
| | 2 | 21.37 ms | 20.99 ms | 21.79 ms | 21.39 ms | 21.79 ms | 21.25 ms |
| | 4 | 25.36 ms | 24.81 ms | 29.09 ms | 25.12 ms | 29.09 ms | 25.19 ms |
| | 8 | 34.54 ms | 33.71 ms | 36.45 ms | 34.54 ms | 36.45 ms | 34.76 ms |
| | 16 | 48.57 ms | 48.02 ms | 49.16 ms | 48.62 ms | 49.16 ms | 48.98 ms |
| | 32 | 78.97 ms | 77.86 ms | 83.66 ms | 78.61 ms | 80.64 ms | 83.66 ms |

I'll come back later (today?) with results of batch vectorization

Vinno97 commented 1 year ago

Batching update: I switched to EleutherAI/pythia-160m, as it uses FlashCausalLM, where batching is already handled more generally. These changes should still work for the other models as well, though.

One pretty significant speed-up in the batched calculation is that I return at most n top_n_tokens, even when there are multiple tokens with equal probabilities. I don't know how often this would happen in real life, but this decision lets me take a much simpler sampling approach: I can directly use the values from torch.topk and don't have to do another check to find any values equal to the smallest logprob. It also means I don't have to sort again, since topk can return sorted values.
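
The batched calculation itself boils down to something like this (simplified sketch with illustrative names, not the exact PR code):

```python
import torch

def batched_top_tokens(logits: torch.Tensor, top_n: int):
    """Sketch: top-n token ids and logprobs for the whole batch in one call.
    logits: [batch_size, vocab_size]."""
    logprobs = torch.log_softmax(logits, dim=-1)
    # topk already returns its values sorted, so no extra sort is needed;
    # ties at the cutoff are simply truncated to n results.
    top_logprobs, top_ids = torch.topk(logprobs, k=top_n, dim=-1, sorted=True)
    return top_ids, top_logprobs
```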


| Parameter | Value |
|---|---|
| Model | EleutherAI/pythia-160m |
| Sequence Length | 10 |
| Decode Length | 8 |
| Batch size | 64 |

| Description | Decode (token) Average | Lowest | Highest | p50 | p90 | p99 |
|---|---|---|---|---|---|---|
| top-n-tokens=0 | 14.26 ms | 13.07 ms | 21.49 ms | 13.54 ms | 17.47 ms | 21.49 ms |
| top-n-tokens=20, fully unbatched (789d809e515e506b64b6d7d48e1dfa21b57a2f8a) | 417.62 ms | 413.64 ms | 421.49 ms | 417.68 ms | 421.19 ms | 421.49 ms |
| top-n-tokens=20, batched calculation, unbatched decoding (3cc111daadb8bdacf5018e5ab9954f9e231dd986) | 61.12 ms | 60.13 ms | 66.89 ms | 60.50 ms | 63.32 ms | 66.89 ms |
| top-n-tokens=20, batched calculation, batched decoding per sequence (d45982114af1646e2171d4ff4e89ff7e070f68a8) | 30.51 ms | 29.60 ms | 34.24 ms | 29.96 ms | 33.61 ms | 34.24 ms |
| top-n-tokens=5, batched calculation, batched decoding per sequence (d45982114af1646e2171d4ff4e89ff7e070f68a8) | 18.71 ms | 18.21 ms | 21.62 ms | 18.53 ms | 19.46 ms | 21.62 ms |
| typical-p=0.5 (reference) | 19.46 ms | 18.70 ms | 24.74 ms | 19.06 ms | 20.20 ms | 24.74 ms |

OlivierDehaene commented 1 year ago

Very cool thanks for the benchmarking results. I will review your PR tomorrow!

njhill commented 1 year ago

Yes, thanks @Vinno97 this looks great!

One pretty significant speed-up in the batched calculation is that I return at most n top_n_tokens, even when there are multiple tokens with equal probabilities. I don't know how often this would happen in real life

Unfortunately it's actually very common with 16-bit precision. Not sure how much it matters, though. But say n=3 and there are a bunch of tokens tied for second place: you would get two of them arbitrarily in positions 2 and 3.

Vinno97 commented 1 year ago

Unfortunately it's actually very common with 16 bit precision.

Just confirmed it with a quick experiment. I'll try some things to see how efficiently I can implement "n-ish-top-tokens". Then we'll also need to see which of the implementations we would actually prefer.
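
Roughly, the idea is the following (sketch only; the real version has to stay batched and avoid extra GPU syncs):

```python
import torch

def top_n_ish_tokens(logprobs: torch.Tensor, top_n: int):
    """Sketch: at least top_n tokens, extended to include every token
    tied with the n-th highest logprob. logprobs: [vocab_size]."""
    top_vals, _ = torch.topk(logprobs, k=top_n)
    cutoff = top_vals[-1]
    # Keep everything at or above the n-th value, then sort descending.
    ids = torch.nonzero(logprobs >= cutoff, as_tuple=False).squeeze(-1)
    vals = logprobs[ids]
    order = torch.argsort(vals, descending=True)
    return ids[order], vals[order]
```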

@OlivierDehaene functionally, the PR is still WIP. Feedback on the way I'm implementing it would be very much appreciated, though. Especially the Rust changes were mostly a process of adding code in the right places until it worked.

HendrikStrobelt commented 1 year ago

While not committing any code or experiments, I want to thank you all for thinking about this. This feature is a major differentiator against closed text-in, text-out models, and it is especially important for XAI approaches. So, here comes some admiration :)

Vinno97 commented 1 year ago

@OlivierDehaene @njhill I've implemented the "top-n-ish-tokens" and optimized it as much as possible. It's not quite as fast as before, but is more reliable and actually scales better than the previous version with very large batch sizes.

| Step | Batch Size | Average | Lowest | Highest | p50 | p90 | p99 |
|---|---|---|---|---|---|---|---|
| Prefill | 64 | 33.32 ms | 30.41 ms | 50.16 ms | 31.31 ms | 45.20 ms | 50.16 ms |
| Decode (token) | | 22.24 ms | 21.10 ms | 28.08 ms | 21.68 ms | 25.86 ms | 28.08 ms |
| Decode (total) | | 155.67 ms | 147.71 ms | 196.55 ms | 151.76 ms | 181.02 ms | 196.55 ms |

A lot of time is actually spent copying a small tensor with the requested number of top tokens to the GPU.

In the profile, this shows up in the nonzero call, though that's only because it's a synchronization point (the "plusplus" function is the new implementation).

I don't immediately see a method to reduce this.

OlivierDehaene commented 1 year ago

You can add a top_n_tokens_tensor GPU tensor to the batch object and create it in Batch.from_pb. Then you can use the List[int] or torch.Tensor depending on the situation. You need to make sure that both are synced though.
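
Something along these lines (untested sketch; the field and protobuf names are illustrative, not necessarily the actual ones):

```python
import torch
from dataclasses import dataclass
from typing import List

@dataclass
class Batch:
    # ...existing fields elided...
    top_n_tokens: List[int]            # per-request values, for the Python side
    top_n_tokens_tensor: torch.Tensor  # the same values, kept on the GPU

    @classmethod
    def from_pb(cls, pb, device: torch.device) -> "Batch":
        top_n_tokens = [r.top_n_tokens for r in pb.requests]
        return cls(
            top_n_tokens=top_n_tokens,
            # Built once per batch so decode steps don't re-copy it to the GPU.
            top_n_tokens_tensor=torch.tensor(
                top_n_tokens, device=device, dtype=torch.int64
            ),
        )
```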

Vinno97 commented 1 year ago

Sounds like a good solution. I'll try on Monday 👍

Vinno97 commented 1 year ago

It saved about a millisecond of latency. Happy that this part is a third faster, but the vast majority of the added latency comes from decoding the tokens, which is beyond my control. I'm already using batch_decode, so there's not much to speed up there. There's one more thing I want to try, which is batching detokenization across the entire batch (see the sketch below). I'm not holding my breath for a big speed-up, however. Unless there's an absolute target latency, I think this is as good as it's going to get for now.
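
Concretely, batching detokenization across the whole batch would look roughly like this (sketch with made-up names; whether tokenizers actually benefits from the larger batch is exactly what I want to measure):

```python
from typing import List
from transformers import PreTrainedTokenizerBase

def decode_top_tokens(
    tokenizer: PreTrainedTokenizerBase,
    top_ids_per_sequence: List[List[int]],
) -> List[List[str]]:
    """Sketch: one batch_decode call for the whole batch instead of one
    call per sequence; each top token is decoded as its own one-id entry."""
    lengths = [len(ids) for ids in top_ids_per_sequence]
    flat = [[token_id] for ids in top_ids_per_sequence for token_id in ids]
    flat_texts = tokenizer.batch_decode(flat, skip_special_tokens=False)
    # Re-split the flat list back into per-sequence lists.
    out, start = [], 0
    for n in lengths:
        out.append(flat_texts[start:start + n])
        start += n
    return out
```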

Further optimizations should perhaps come from shared decoding of the sequences, which tokenizers doesn't support to my knowledge.

Vinno97 commented 1 year ago

Cleaned up the code, added CLI flags, added support in the Python client and changed the PR to "ready to review".

For reference, here's the current performance. Note that, as @njhill said before, 5 is a much more realistic top-n than 20 or 64.

| Parameter | Value |
|---|---|
| Model | EleutherAI/pythia-160m |
| Sequence Length | 10 |
| Decode Length | 8 |

| Description | Prefill Average | Decode (token) Average | Decode (total) |
|---|---|---|---|
| top-n-tokens=0 | 21.79 ms | 13.54 ms | 94.75 ms |
| typical-p=0.5 (reference) | 28.79 ms | 19.46 ms | 136.19 ms |
| top-n-tokens=5 | 26.72 ms | 17.81 ms | 124.67 ms |
| top-n-tokens=20 | 30.67 ms | 22.45 ms | 157.12 ms |
| top-n-tokens=64 | 40.20 ms | 34.53 ms | 241.69 ms |

Vinno97 commented 1 year ago

Closed by merging #617

jinsong-mao commented 4 months ago

Hi @Vinno97, can you tell me how to dump the Perfetto trace profile, so that we can visualize the internals better?

thanks

Vinno97 commented 4 months ago

How I got that trace?

Iirc, I just wrapped the forward pass in a PyTorch profiler context, described here: https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html. I did have issues on my machine where I couldn't profile the GPU itself, but you can check https://pytorch.org/blog/accelerating-generative-ai/ for a good overview of how to profile and optimize.
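
In code, that wrapper looks roughly like this (minimal sketch based on that recipe; the Linear model and random input are just stand-ins for the real forward pass):

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Stand-in model and input just to make the sketch runnable; in TGI you
# would wrap the model's real forward pass instead.
model = torch.nn.Linear(128, 128)
inputs = torch.randn(4, 128)

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, with_stack=True) as prof:
    with torch.no_grad():
        model(inputs)

# The exported trace opens in chrome://tracing or https://ui.perfetto.dev
prof.export_chrome_trace("trace.json")
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=20))
```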

I hope this helps!