@Vinno97 it's on my todo list to contribute a PR for this. We implemented it in our own fork of TGI (which is a bit behind on the latest changes in this repo, but also does some things differently and includes additional features like this).
You can see the relevant code here, and for input tokens here. Note that we have more granular flags for choosing specifically which response details you'd like, and we don't compute the ones that aren't requested.
As well as requesting `top_n` candidates with logprobs, you can independently request the rank for every chosen/input token (i.e. its position in the tokens ordered by score at that step).
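For illustration, computing such a rank only needs a comparison against the chosen token's score (a minimal sketch with dummy values, not our fork's actual code):

```python
import torch

# Illustrative sketch of the "rank" detail: the 1-based position of each chosen
# token when the whole vocabulary is ordered by score at that step.
scores = torch.randn(3, 32000)                # (steps, vocab_size), dummy values
chosen_ids = torch.tensor([42, 7, 1337])      # token actually chosen at each step

chosen_scores = scores.gather(1, chosen_ids.unsqueeze(1))   # (steps, 1)
ranks = (scores > chosen_scores).sum(dim=1) + 1             # rank 1 = highest score
```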
@njhill, that would be great! What's the latency hit? That's my main concern, especially in prefill, as you end up doing a massive sort on the prompt_tokens × vocab_size matrix.
Note that we have more granular flags for choosing specifically which response details you'd like and don't compute the ones that aren't requested.
We have something similar now, if I understood you correctly.
I have added a draft PR with this feature. I merged @njhill's code and updated it to work with the newer codebase. See my PR for its current limitations.
What's the latency hit?
@OlivierDehaene I haven't really measured the latency impact yet, but I haven't noticed much. It only applies to requests that ask for it specifically and most of them don't. A max top_n value can be configured at deploy time, so this can be disabled in cases where the deployer doesn't want to allow it.
you end up doing a massive sort on the prompt_tokens x vocab size matrix.
Why do you need to sort the entire vocab? It's just a top_k operation, right?
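For illustration (a minimal sketch with dummy shapes, not TGI code): selecting the n best logprobs per position only needs a partial selection, not a full sort of the whole matrix.

```python
import torch

# torch.topk does a partial selection per row, instead of sorting the entire
# (prompt_tokens, vocab_size) matrix.
logits = torch.randn(4, 32000)                 # (prompt_tokens, vocab_size)
logprobs = torch.log_softmax(logits, dim=-1)

top_values, top_ids = torch.topk(logprobs, k=5, dim=-1)   # partial selection
# the full sort the concern above refers to would instead be:
# sorted_values, sorted_ids = torch.sort(logprobs, dim=-1, descending=True)
```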
We have something similar now, if I understood you correctly.
I think so but I'm not sure that it's as granular as what we had done, i.e. separate flags for requesting input tokens, logprobs, token ranks, top_n candidates, etc.
I have added a draft PR with this feature.
Thanks @Vinno97, I was on vacation last week and am still catching up, but I will try to take a look at your PR soon.
We could also look into vectorizing this in the batch dimension, as was done for the logits processors/warpers.
I thought I responded before, but apparently I never finished my reply. Anyways:
We could also look into vectorizing this in the batch dimension, as was done for the logits processors/warpers.
I thought about this as well when porting the code, but I first wanted to get the initial code up and running before getting to anything like that.
What's the latency hit?
Very significant. Though admittedly, I'm currently calling `.item()` and decoding every top token separately in a loop. Here's some "quick" benchmarking (I should've just profiled in the end), done using `google/flan-t5-small` on 2 Nvidia L4 GPUs.
I wasn't able to work on it last week due to some other priorities that came up. However, these tasks are now done and I should have time to take it up again this week.
I initially wanted to just optimize the current function and implement it for all models. However, these benchmarks make me think it may be smart to immediately look into vectorizing over the batch.
Thanks @Vinno97. Another difference in our fork is that we do all of the token decoding on the Rust side. Also, we have the max allowed value of n set to 5... 50 seems pretty high, but I guess it depends on the use cases.
I purposefully set n very high, but maybe that was indeed a bit too high. For top-n-tokens=5 there is also a noticeable hit, but less so.
Step | Batch Size | Average | Lowest | Highest | p50 | p90 | p99 |
---|---|---|---|---|---|---|---|
Decode (token) | 1 | 18.83 ms | 18.50 ms | 19.49 ms | 18.80 ms | 19.49 ms | 18.77 ms |
Decode (token) | 2 | 21.37 ms | 20.99 ms | 21.79 ms | 21.39 ms | 21.79 ms | 21.25 ms |
Decode (token) | 4 | 25.36 ms | 24.81 ms | 29.09 ms | 25.12 ms | 29.09 ms | 25.19 ms |
Decode (token) | 8 | 34.54 ms | 33.71 ms | 36.45 ms | 34.54 ms | 36.45 ms | 34.76 ms |
Decode (token) | 16 | 48.57 ms | 48.02 ms | 49.16 ms | 48.62 ms | 49.16 ms | 48.98 ms |
Decode (token) | 32 | 78.97 ms | 77.86 ms | 83.66 ms | 78.61 ms | 80.64 ms | 83.66 ms |
I'll come back later (today?) with the results of batch vectorization.
Batching update:
I switched to `EleutherAI/pythia-160m`, as it uses `FlashCausalLM`, where batching is already more generally included (though these changes should still work for the other models as well).

One pretty significant speed-up in the batched calculation is that I return at most `n` top tokens (`top_n_tokens`), even when there are multiple tokens with equal probabilities. I don't know how often this would happen in real life, but this decision enables a much simpler sampling approach: I can directly use the values from `torch.topk` and don't have to do another check for values equal to the smallest logprob. It also means I don't have to sort again, since `topk` can return sorted values.
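Roughly, the batched selection looks like this (a simplified sketch, not the exact PR code):

```python
from typing import List, Tuple

import torch

# One torch.topk call over the whole batch; each request then keeps only its own
# requested number of candidates. Ties at the cutoff are simply truncated, which
# is the trade-off discussed below.
def batched_top_tokens(
    logprobs: torch.Tensor, top_n_tokens: List[int]
) -> List[List[Tuple[int, float]]]:
    """logprobs: (batch, vocab_size); top_n_tokens: requested n per request."""
    max_n = max(top_n_tokens)
    if max_n == 0:
        return [[] for _ in top_n_tokens]
    # topk returns values already sorted in descending order, so no extra sort.
    top_values, top_ids = torch.topk(logprobs, k=max_n, dim=-1)
    top_values, top_ids = top_values.tolist(), top_ids.tolist()
    return [
        list(zip(ids[:n], values[:n]))
        for ids, values, n in zip(top_ids, top_values, top_n_tokens)
    ]
```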
Parameter | Value |
---|---|
Model | EleutherAI/pythia-160m |
Sequence Length | 10 |
Decode Length | 8 |
Batch size | 64 |
Description | Decode (token) Average | Lowest | Highest | p50 | p90 | p99 |
---|---|---|---|---|---|---|
top-n-tokens=0 | 14.26 ms | 13.07 ms | 21.49 ms | 13.54 ms | 17.47 ms | 21.49 ms |
top-n-tokens=20, fully unbatched (789d809e515e506b64b6d7d48e1dfa21b57a2f8a) | 417.62 ms | 413.64 ms | 421.49 ms | 417.68 ms | 421.19 ms | 421.49 ms |
top-n-tokens=20, batched calculation, unbatched decoding (3cc111daadb8bdacf5018e5ab9954f9e231dd986) | 61.12 ms | 60.13 ms | 66.89 ms | 60.50 ms | 63.32 ms | 66.89 ms |
top-n-tokens=20, batched calculation, batched decoding per sequence (d45982114af1646e2171d4ff4e89ff7e070f68a8) | 30.51 ms | 29.60 ms | 34.24 ms | 29.96 ms | 33.61 ms | 34.24 ms |
top-n-tokens=5, batched calculation, batched decoding per sequence (d45982114af1646e2171d4ff4e89ff7e070f68a8) | 18.71 ms | 18.21 ms | 21.62 ms | 18.53 ms | 19.46 ms | 21.62 ms |
typical-p=0.5 (reference) | 19.46 ms | 18.70 ms | 24.74 ms | 19.06 ms | 20.20 ms | 24.74 ms |
Very cool thanks for the benchmarking results. I will review your PR tomorrow!
Yes, thanks @Vinno97 this looks great!
One pretty significant speed-up in the batched calculation is that I return at most `n` top tokens (`top_n_tokens`), even when there are multiple tokens with equal probabilities. I don't know how often this would happen in real life
Unfortunately it's actually very common with 16 bit precision. Not sure how much it matters, though. But say n=3 and there are a bunch of tokens tied for second place: you would get two of these arbitrarily in positions 2 and 3.
Unfortunately it's actually very common with 16 bit precision.
Just confirmed it with a quick experiment. I'll try some things to see how efficiently I can implement "n-ish-top-tokens". Then we'll also need to see which one of the implementations we would actually prefer.
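A rough, untested sketch of what I have in mind for the tie-aware variant:

```python
from typing import List, Tuple

import torch

# Keep every token whose logprob is at least the n-th best value, so
# equal-probability tokens at the cutoff are not dropped arbitrarily.
# This can return slightly more than n tokens per request.
def top_n_ish_tokens(logprobs: torch.Tensor, n: int) -> List[Tuple[List[int], List[float]]]:
    """logprobs: (batch, vocab_size) -> per-row (token_ids, values), ties included."""
    top_values, _ = torch.topk(logprobs, k=n, dim=-1)   # (batch, n), sorted descending
    cutoff = top_values[:, -1:]                          # n-th best value per row
    mask = logprobs >= cutoff                            # keeps ties at the cutoff
    results = []
    for row_logprobs, row_mask in zip(logprobs, mask):
        ids = row_mask.nonzero(as_tuple=True)[0]         # implicit GPU->CPU sync point
        values = row_logprobs[ids]
        order = torch.argsort(values, descending=True)   # cheap: only ~n elements
        results.append((ids[order].tolist(), values[order].tolist()))
    return results
```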
@OlivierDehaene functionally, the PR is still WIP. Feedback on the way I'm implementing it would be very much appreciated, though. Especially the Rust changes were mostly a process of adding code in the right places until it worked.
While not contributing any code or experiments, I want to thank you all for thinking about this. This feature is a major differentiator compared to closed text-in, text-out models, and it is especially important for XAI approaches. So, here comes some admiration :)
@OlivierDehaene @njhill I've implemented the "top-n-ish-tokens" and optimized it as much as possible. It's not quite as fast as before, but is more reliable and actually scales better than the previous version with very large batch sizes.
Step | Batch Size | Average | Lowest | Highest | p50 | p90 | p99 |
---|---|---|---|---|---|---|---|
Prefill | 64 | 33.32 ms | 30.41 ms | 50.16 ms | 31.31 ms | 45.20 ms | 50.16 ms |
Decode (token) | 64 | 22.24 ms | 21.10 ms | 28.08 ms | 21.68 ms | 25.86 ms | 28.08 ms |
Decode (total) | 64 | 155.67 ms | 147.71 ms | 196.55 ms | 151.76 ms | 181.02 ms | 196.55 ms |
A lot of time is actually spent copying a small tensor with the requested number of top tokens to the GPU. In the profile, this shows up in the `nonzero` call, though that's only because it is a sync point (the "plusplus" function is the new implementation). I don't immediately see a way to reduce this.
You can add a `top_n_tokens_tensor` GPU tensor to the batch object and create it in `Batch.from_pb`. Then you can use the `List[int]` or the `torch.Tensor`, depending on the situation.
You need to make sure that both stay in sync, though.
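Roughly something like this (illustrative class/field and protobuf names, not the exact TGI definitions):

```python
from dataclasses import dataclass
from typing import List

import torch

# Build both a Python list and a device tensor of the per-request top-n counts
# once, in from_pb, so the decode path never has to copy this small tensor to the
# GPU again. When the batch is filtered or concatenated, both fields must be
# updated together so they stay in sync.
@dataclass
class Batch:
    top_n_tokens: List[int]            # convenient on the Python/Rust boundary
    top_n_tokens_tensor: torch.Tensor  # convenient inside the model's decode path

    @classmethod
    def from_pb(cls, pb, device: torch.device) -> "Batch":
        top_n_tokens = [r.top_n_tokens for r in pb.requests]   # hypothetical pb layout
        return cls(
            top_n_tokens=top_n_tokens,
            top_n_tokens_tensor=torch.tensor(top_n_tokens, device=device, dtype=torch.int64),
        )
```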
Sounds like a good solution. I'll try on Monday 👍
It resolved a millisecond of latency. Happy that this part is a third faster, but the vast majority of the added latency comes from decoding the tokens, which is beyond my control. I'm already using `batch_decode`, so there's not much to speed up there. There's one more thing I want to try, which is batching detokenization across the entire batch. I'm not holding my breath for a big speed-up, however.
Unless there's an absolute target latency, I think that is as good as it's going to get for now.
Further optimizations should perhaps come from shared decoding of the sequences, which `tokenizers` doesn't support, to my knowledge.
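The batched-detokenization idea mentioned above would be roughly the following (an untested sketch with dummy ids):

```python
from transformers import AutoTokenizer

# Flatten every request's top-token ids into a single batch_decode call and
# slice the results back per request afterwards.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")

top_ids_per_request = [[318, 262, 257], [464, 290]]     # dummy ids, variable n per request
flat_ids = [[i] for ids in top_ids_per_request for i in ids]
flat_texts = tokenizer.batch_decode(flat_ids, skip_special_tokens=False)

texts_per_request, offset = [], 0
for ids in top_ids_per_request:
    texts_per_request.append(flat_texts[offset:offset + len(ids)])
    offset += len(ids)
```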
Cleaned up the code, added CLI flags, added support in the Python client and changed the PR to "ready to review".
For reference, here's the current performance. Note that as @njhill said before, 5 is a way more realistic top-n than 20 or 64.
Parameter | Value |
---|---|
Model | EleutherAI/pythia-160m |
Sequence Length | 10 |
Decode Length | 8 |
Description | Prefill | Average Decode (token) | Average Decode (total) |
---|---|---|---|
top-n-tokens=0 | 21.79 ms | 13.54 ms | 94.75 ms |
typical-p=0.5 (reference) | 28.79 ms | 19.46 ms | 136.19 ms |
top-n-tokens=5 | 26.72 ms | 17.81 ms | 124.67 ms |
top-n-tokens=20 | 30.67 ms | 22.45 ms | 157.12 ms |
top-n-tokens=64 | 40.20 ms | 34.53 ms | 241.69 ms |
Closed by merging #617
Hi @Vinno97, can you tell me how to dump the Perfetto trace profile, so that we can visualize the internals better? Thanks.
How I got that trace?
Iirc, I just wrapped the forward pass in a PyTorch profiler context, described here: https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html. I did have issues on my machine where I couldn't profile the GPU itself, but you can check https://pytorch.org/blog/accelerating-generative-ai/ for a good view on how to profile and optimize.
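A minimal sketch of that setup (not my exact code): profile the forward pass and export a Chrome/Perfetto-compatible trace that can be opened at https://ui.perfetto.dev or chrome://tracing.

```python
import torch
from torch.profiler import profile, ProfilerActivity

model = torch.nn.Linear(1024, 1024)          # stand-in for the real forward pass
inputs = torch.randn(8, 1024)

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, record_shapes=True) as prof:
    with torch.no_grad():
        model(inputs)

prof.export_chrome_trace("trace.json")       # load this file in Perfetto
```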
I hope that this is sufficiently helpful.
Feature request
It would be great if the API could return a list of most probable tokens (along with their logprobs) for each step. This could be useful for many downstream tasks that require sampling of these.
Motivation
Many LLM-based applications benefit not just from having the most probable token, but also a list of output probabilities.

- `lm-evaluation-harness` relies on having the output logits for its multiple-choice benchmarks. Having this feature will allow using this server to speed up benchmarking tasks.
- Without `logit_bias`, you could use `top_k`, hope your prompting is sufficient such that the first `k` tokens are all your other classes, and that the `softmax` is only taken over these tokens. It is, however, impossible to know for sure (or use `best-of` and `temperature` and hope you get all labels' probabilities).

Your contribution
I'd be willing to contribute to this feature, but I'd love some guidance on that. Some initial reading of the code makes me think this won't require big changes in the code, though it'll touch many different parts of it. For that reason I'd want to be sure of the preferred approach before wasting both my and your time.
I guess we could add this in (all variations of) `Model#generate_token` and add the top n values from `next_token_logits` to this iterator: https://github.com/huggingface/text-generation-inference/blob/b7327205a6f2f2c6349e75b8ea484e1e2823075a/server/text_generation_server/models/flash_causal_lm.py#L891-L902
And then also add it to the `Generation`? https://github.com/huggingface/text-generation-inference/blob/b7327205a6f2f2c6349e75b8ea484e1e2823075a/server/text_generation_server/models/flash_causal_lm.py#L974-L982
Provided we make sure the input and output are passed on correctly everywhere else they're touched, would this be the way to go?
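To make that concrete, something like this on the Python side (purely hypothetical field names, not the actual `Generation` definition):

```python
from dataclasses import dataclass
from typing import List, Optional

# Hypothetical illustration of the proposal: the per-step top-n candidates
# would ride along with the existing fields of each Generation.
@dataclass
class TopToken:
    token_id: int
    logprob: float
    text: str

@dataclass
class Generation:
    request_id: int
    token_id: int
    token_logprob: float
    token_text: str
    top_tokens: Optional[List[TopToken]] = None   # new: n most probable candidates
```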
I'm not a Rust programmer, but the code seems straightforward enough (famous last words?). Just add the parameter everywhere the other params like `top_k` are also mentioned.

Regarding what to call this feature: OpenAI has an integer parameter called "logprobs" that does just this. However, this won't work since TGI already has another parameter with this name. Perhaps `top_n_logprobs`, `top_k_logprobs`, `top_n_tokens`, or `top_k_tokens`?