naed90 closed this issue 6 months ago.
Wow, this is so good. And so detailed. Amazing work!
I had a stupid question. What makes the other two changes not ready for merging? Would I face any issues if I merge all the commits together and run that version?
Hey @nivibilla and thanks for commenting!
The other 2 commits can certainly be merged, the code is just a bit ugly and currently doesn’t yet support beam-search. Will get around to it soon to clean it up a bit and then open a PR for that too :)
You can certainly merge it to try out and see the faster throughput!
Great thank you! I will test it out tomorrow and let you know
Hey @naed90 . By any chance are these optimisations only for A100s? I tested it on an A10G. And there was no improvement in performance for a 7B open llama.
My fork is here https://github.com/nivibilla/vllm
It should work independent of the GPU.
What's your exact setup + the command line you're running to benchmark? I'll try to replicate it and take a look at the profiling of the program.
Hi @naed90 I have tested vllm-fast single_query_cached_kv_attention_kernel with vllm main on A100 80G.
python3 benchmarks/benchmark_throughput.py --backend vllm --dataset ./ShareGPT_V3_unfiltered_cleaned_split.json --model lmsys/vicuna-13b-v1.3 --num-prompts=1000
| vLLM | metric | 1st | 2nd | 3rd | 4th |
|---|---|---|---|---|---|
| main | requests/s | 3.72 | 3.72 | 3.73 | 3.72 |
| main | tokens/s | 1779.35 | 1777.35 | 1784.08 | 1779.35 |
| opt | requests/s | 3.90 | 3.90 | 3.92 | 3.91 |
| opt | tokens/s | 1865.76 | 1865.41 | 1874.88 | 1870.59 |

It shows approximately a 5% improvement in throughput.
Thanks for trying it out!
Makes sense, that first commit accounts for a 10% improvement in the LLaMA13B case. Not surprising that it gives 5% in Vicuna.
Got a chance to try it out also with the other 2 commits?
Hi @naed90
I have updated the results. Cheers.
python3 benchmarks/benchmark_throughput.py --backend vllm --dataset ./ShareGPT_V3_unfiltered_cleaned_split.json --model lmsys/vicuna-13b-v1.3 --num-prompts=1000
| vLLM | metric | 1st | 2nd | 3rd | 4th |
|---|---|---|---|---|---|
| main | requests/s | 3.72 | 3.72 | 3.73 | 3.72 |
| main | tokens/s | 1779.35 | 1777.35 | 1784.08 | 1779.35 |
| opt attention | requests/s | 3.90 | 3.90 | 3.92 | 3.91 |
| opt attention | tokens/s | 1865.76 | 1865.41 | 1874.88 | 1870.59 |
| opt sampler | requests/s | 4.59 | 4.59 | 4.59 | 4.60 |
| opt sampler | tokens/s | 2195.60 | 2196.15 | 2194.36 | 2199.64 |

It shows approximately a 5% throughput improvement from the optimized attention and approximately a 23% improvement from the optimized sampler.
cc @WoosukKwon
Nice!! It's possible to try them together (these improvements can be stacked) -- they're on 2 branches now to make the pull requests easier, but it's possible to create a branch with both (and then you'd probably get roughly 28% improvement?)
Hi @naed90
| vLLM | metric | 1st | 2nd | 3rd | 4th |
|---|---|---|---|---|---|
| optimized attention and sampler | requests/s | 4.87 | 4.88 | 4.86 | 4.85 |
| optimized attention and sampler | tokens/s | 2328.98 | 2332.31 | 2325.95 | 2317.19 |

In summary, by combining the attention and sampler optimizations, Vicuna 13B achieves a remarkable ~30% increase in throughput compared to the main branch.
@naed90
I am using an EC2 G5 instance with a single GPU for a 7B OpenLLaMA model. The dataset is a list of sentiment analysis prompts. There must have been something wrong in my testing, as everyone else seems to have got improvements. Let me redo the testing and get back.
Also good to note that I am doing it inside a notebook on Databricks.
Thanks @zhyncs for confirming and testing Vicuna!
Sounds good @nivibilla, looking forward to it.
Yeah, I still don't see the improvements.
Main
my fork
@naed90 Could it be that because it's on a single GPU the 7B overflows to the CPU, and since I am only requesting 5 tokens, I'm not able to see the difference?
Spillage to the CPU can certainly be a problem -- how much GPU memory do you have and what are your configs for vLLM? I.e., can you send the output of running `nvidia-smi` and the configs you have in the vLLM engine args?
GPU
LLM Args
Using the slow tokenizer because of OpenLLaMA.
Hi, @zhyncs, great work! Just curious: what is the difference between 1st/2nd/3rd/4th? Does it mean different test cases?
BTW, I wonder whether the throughput will drop if you use tensor parallelism in a multi-GPU environment.
Hi @irasin
Yes, conducting multiple rounds of testing can help avoid inaccurate results from a single round. In our testing scenario and hardware configuration, vicuna 13b does not require tensor parallelism. If you are interested, you can test it yourself. Thank you.
Thanks @zhyncs
Just FYI, I tested the Vicuna 7B model with the main branch on an NVIDIA A10 GPU. The results:

| test case | throughput |
|---|---|
| single A10 without tp | 1.83 requests/s, 874.36 tokens/s |
| 4 A10 with tp_size=4 | 2.94 requests/s, 1407.26 tokens/s |

Also, it's interesting that throughput grows when using tensor parallelism.
Hi @irasin
Thank you for your testing and for providing the data. When we consider trying tensor parallelism, it usually means the model is too large to fit on a single device. If the model can fit on a single device, using multiple devices with tensor parallelism would not increase the overall throughput in proportion to the number of devices, which would be cost-inefficient. Thank you.
Hi, @zhyncs, thanks for your explanation. I got what you said about cost-inefficiency. In fact, my calculations were wrong: if I use dp=4 with four single A10s, the total throughput should be 1.83 * 4 = 7.32 reqs/s, which is much higher than the 2.94 reqs/s of the tp=4 case.
Do you think layer-group-wise kv-cache recycling can improve the throughput further?
@wejoncy Interesting 🤔 you mean off-loading some of the kvcache to CPU RAM after we pass specific layers in the forward propagation? Or do you mean something different and I misunderstood?
Some thoughts on offloading parts to CPU:

1) The CPU-GPU bandwidth limit would probably limit us to not being able to grow the kvcache by too much, without this bandwidth limit becoming the new bottleneck. We could maybe grow the cache by a few tens of % before this transfer becomes the bottleneck. Didn't try it though, so I may be wrong. This is similar to what vLLM does with swap pages (tbh, in most of my experiments I didn't really see a big impact from the swap pages, but maybe it helps on other benchmarks).
2) I am unsure if growing the batch size would actually have a dramatic effect. Unless I am missing something, growing the kv cache would only give us a larger batch size, right? I tried limiting the kv cache size to 35GB and then comparing it with 70GB. Yes, there was a speedup, but it was in the low tens of % points. So I am unsure if putting in a lot of work to grow the kv cache size by say 30-40% by offloading to CPU RAM (before the communication becomes a bottleneck) would end up with an overall improvement of more than say 10-20% roughly.
What do you think? Note that I did not experiment with these specific aspects enough, so there is a chance that I am mistaken on some stuff.
Another interesting open direction: compression. What if we could compress the kv cache? A sequence of length say 400 tokens would imply, in LLaMA13B, a kv cache usage of seq_len * num_layers * (key_size_in_bytes + value_size_in_bytes) = 400 * 40 * (10240 + 10240) = ~327MB. This is a bit ironic, since the same 400-token-long sequence is originally of size 400 * log(token space), which is just under 0.25KB. It would potentially be interesting to compress the kvcache when it's stored in GPU RAM. As the bottleneck of the attention kernel is memory (and the compute sits almost completely unutilized), if we manage to compress the kvcache by say 2x and decompress inside the kernel in registers/L1/L2 cache, it could really improve runtime.
Sure, the 327MB kv cache isn't just a function of the original string but also a function of the model weights, which might imply that we need the compression to be a function of the entire kv cache.
Creating a compression algorithm here which will be efficient distributed-wise, compression-wise, etc. might be rather hard.
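For reference, here is a back-of-the-envelope sketch of that arithmetic (hidden size 5120 and fp16, chosen to match the 10240-byte figure above, are assumptions):

```python
# Back-of-the-envelope check of the kv-cache size above (assumed LLaMA-13B-like settings).
seq_len = 400
num_layers = 40
key_size_in_bytes = value_size_in_bytes = 5120 * 2  # hidden size * 2 bytes (fp16)

kv_cache_bytes = seq_len * num_layers * (key_size_in_bytes + value_size_in_bytes)
print(kv_cache_bytes / 1e6)  # ~327.7 MB for a single 400-token sequence
```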
Thanks for your great work, @naed90. I love this proposal and want to see how far it can go further.
Here I share my results of LLaMA running on A100-80GB for anyone who might be interested.

| Throughput (tokens/s) | 13B (tp=1) | 30B (tp=2) | 30B (tp=4) | 65B (tp=2) | 65B (tp=4) | 65B (tp=8) |
|---|---|---|---|---|---|---|
| vLLM | 1923.29 | 1368.42 | 1654.71 | 415.98 | 1230.65 | 1315.74 |
| vLLM-fast | 2451.05 | 1618.01 | 2013.47 | 450.61 | 1426.67 | 1552.44 |
| Improvement | 1.27x | 1.18x | 1.22x | 1.08x | 1.16x | 1.18x |
Note that those numbers are the result of a single run and could change across runs.
I also attach the workarounds I used for some issues in my Conda environment.
```shell
# Needed to downgrade protobuf
pip uninstall protobuf
pip install protobuf==3.20.3

# Needed to reinstall Ray
pip install "ray[default]"

# Needed to downgrade pydantic
pip uninstall pydantic
pip install pydantic==1.10.11
```
Hi @naed90, thanks for your quick and detailed response. Your work is really a great improvement.
Yes, you are absolutely correct, and that's what I meant.
I was curious whether we could get a further speedup from a larger batch size by offloading the kv-cache layer by layer to CPU RAM, but I am not sure whether those asyncMemcpy operations on different streams would affect the computation workload to some extent. After your explanation, I believe it wouldn't help improve the throughput even if we could run a higher batch size/more tokens. Thank you.
I agree that compression would be a very good direction to explore. I am not sure if you are familiar with GPTQ, which can compress weights to 1/4 of their size. But we would need to show that the KV-cache values don't have too many outliers, or it will introduce a huge reconstruction error.
One more possible step to improve the performance is to force the stream wait on the same stream.
https://github.com/vllm-project/vllm/blob/c894836108732d0cbb6fce15aeda8de1218a380d/vllm/model_executor/layers/attention.py#L180
It waits on the default (0) stream for now. In my experiment (opt-6.7b), waiting on the cache_stream can reduce the model-forward time by about 3-5%.
Besides, the model (opt-6.7b) running time contributes no more than 40% of the total. I suspect the scheduling logic costs so much time because it is written in Python.
Completely agree that the compression of kv caches is a very interesting yet challenging project :)
Yeah, the current Python code in vLLM really is taking lots of time. Here is a great candidate for a 10-20% improvement: this. The line `tokenizer.convert_tokens_to_string(output_tokens)` currently takes roughly 10-20% of the total time of the program. This is because we iterate sequentially over all the sequences and call this line once per sequence.
The easiest solution would be to somehow decode all the sequences in parallel. It can't be done using python threading, since python has just 1 thread running and multiplexes virtual threads on top of that (as far as I know). It feels that spinning up more processes just for this could be a bit annoying. There has to be a better solution :)
Anyhow, it's a great place to start that would give another 10-20% throughput!
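As a hypothetical sketch of the multi-process route mentioned above (despite its annoyances), each worker process could load its own tokenizer so nothing heavyweight has to be pickled; the helper names and the model argument here are placeholders, not vLLM code:

```python
from concurrent.futures import ProcessPoolExecutor

from transformers import AutoTokenizer

_TOKENIZER = None


def _init_worker(model_name):
    # Each worker process loads its own tokenizer once.
    global _TOKENIZER
    _TOKENIZER = AutoTokenizer.from_pretrained(model_name)


def _detokenize(output_tokens):
    # output_tokens is the list of token strings for one sequence.
    return _TOKENIZER.convert_tokens_to_string(output_tokens)


def detokenize_all(token_lists, model_name):
    # Spreads the per-sequence convert_tokens_to_string calls across processes.
    with ProcessPoolExecutor(max_workers=4, initializer=_init_worker,
                             initargs=(model_name,)) as pool:
        return list(pool.map(_detokenize, token_lists))
```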
Interesting! I was not aware of that stream-wait point. Did you try to change it and see how much it helps?
Yes, I did. I made cache_stream a member of InputMetadata and then changed the code to `cache_event.wait(input_metadata.cache_stream)`. It gives me a 5% speedup on V100.
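A minimal sketch of the change being described, assuming a CUDA device is available (the stream/event names follow the discussion, not necessarily vLLM's exact code):

```python
import torch

cache_stream = torch.cuda.Stream()
cache_event = torch.cuda.Event()

with torch.cuda.stream(cache_stream):
    # ... asynchronous KV-cache copy kernels would be launched here ...
    cache_event.record()  # recorded on cache_stream

# Before: cache_event.wait()              # the default (0) stream waits for the event
# After:  cache_event.wait(cache_stream)  # the cache stream waits instead
cache_event.wait(cache_stream)
```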
Interesting. I have rough statistics for the different stages of the generate pipeline, but I wasn't aware that `convert_tokens_to_string` is so time-consuming.
I think you would be interested in this BUG. It makes the results generated by vLLM not what is expected.
Update: Sorry, I made a mistake here.
Interesting point that you mention. I didn't look too much into the beam-search code. Did one of the project team members take a look?
Nice! Got a commit for that cache_stream change? I'll try to run it too.
Sure, try this one: https://github.com/wejoncy/vllm/commit/0945030eeea9f9d5702671f55cd4e68332aae9a5
@naed90 @oleitersdorf Sorry that I was traveling last week and didn't get a chance to take a detailed look into this issue and its related PRs. Finally got the time to look at it, and I have to say what you have done is impressive. This is very helpful for the vLLM community to understand the performance issues in the current vLLM. Thank you for all your efforts!
Regarding the points mentioned in the issue:
- Using shared memory to accelerate query memory reading makes a lot of sense. We will review [OPTIMIZATION] Optimizes the single_query_cached_kv_attention kernel #420, test the performance on our side and merge that into the main branch. Regarding the memory reading overhead of keys and values, many recent models include multi-query attention (MQA) or grouped-query attention (GQA) (see Optimize MQA Kernel #452). In this case, one key/value head can be shared by many query heads, so there is a potential kernel optimization opportunity here.
- We were aware of the inefficiency in our sampling code. I believe what you have implemented is a valid quick fix for performance. However, as you mentioned, vLLM supports many different sampling methods being applied to the requests in the same batch, which includes beam search. Previously we were thinking of having a pure C++ sampler or a CUDA sampling kernel to mitigate the performance issue. We are open to your input on this as well.
- Regarding parallel tokenization, I think the biggest issue here is that different models have different tokenizers. It's hard to have a one-for-all solution for this. Again, any input is valuable here.

We are happy to discuss more! I believe @WoosukKwon can also comment more on why we have 16 bytes in our kernel.
@naed90 @oleitersdorf Fantastic! Thanks for exploring vLLM and pitching the wonderful idea with detailed explanations. I didn't expect that our attention kernel has such a problem. Thanks again for your finding and solution!
As can be seen above, the main kernel of the program has very low L1/L2 cache utilization. Taking another look at the kernel, it seems that 4x more data is being read from global memory than is needed -- each read to fetch a key-segment from the KV cache is translated into a read of 64 bytes, yet, the thread only uses 16 of those bytes. As the keys in the KV cache have shape [num_blocks, num_heads, head_size/x, block_size, x], these reads are very spaced apart. Similar issues happen with value fetching.
Speaking of this question, I have no idea at the moment. The data layout of the key cache is actually from the original FasterTransformer kernel, and it is optimized for memory access (for the value cache, I slightly modified the layout to better fit into the block-based memory layout in paged attention). Specifically, the key cache shape `[num_blocks, num_heads, head_size/x, block_size, x]` is designed to ensure that the threads in a warp read a contiguous chunk of memory even though they are reading different tokens in a token block. So it is supposed to coalesce the threads' global memory access and thus maximally utilize the GPU memory bandwidth.
Thanks again for very detailed profiling and sharing your findings. I will look into it and double-check whether I misconfigured anything.
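To make that layout concrete, here is a small illustrative sketch (the dimensions and the value of `x` are assumptions for an fp16 cache -- 8 elements = the 16 bytes mentioned above -- not values read from the vLLM source):

```python
import torch

num_blocks, num_heads, head_size, block_size, x = 1024, 40, 128, 16, 8

key_cache = torch.empty(num_blocks, num_heads, head_size // x, block_size, x,
                        dtype=torch.float16)

# One token's key head is spread over head_size // x chunks of x contiguous elements;
# neighbouring tokens of the same block sit next to each other inside every chunk,
# which is what lets the threads of a warp coalesce their global-memory reads.
token_slot = 3
one_token_key = key_cache[0, 0, :, token_slot, :].reshape(head_size)
print(one_token_key.shape)  # torch.Size([128])
```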
Hey @zhuohan123! Hope your travels were pleasant :) Glad to help out any way we can!
Regarding parallel tokenization: the issue is that Python effectively has just one real thread (`import threading` and other libraries just multiplex over this thread to hide I/O costs). So, if we do this in Python, it would require spinning up multiple processes, which could be a mess. Perhaps the right solution is using C++ (where we can have real threads), but then we would probably run into the problem that the tokenizers for some models are written in Python (I guess?).
Hey @WoosukKwon and nice to meet you.
We looked into this more. It turns out that everything is fine :)
We reshaped the tensor, and this drastically increased the L1/L2 hit rates (L1 went from 1% to 43%, L2 from 25% to 33%); however, the total runtime did not decrease.
When taking a second look, we realized that no data was being read twice from the global memory, so all is good -- the easiest way to verify this is that Nvidia Nsight Compute writes how many bytes the kernel read from global memory, and we saw that it is only roughly 5% higher than the theoretical requirement (which can be computed by summing the total lengths of the sequences, times key size, times 2, etc).
Bottom line: looks like after the fix we submitted in the PR, the kernel is now doing the best it possibly can (since it's reading only roughly 5% more data than the theoretical rough estimate, and the DRAM read/write usage is roughly 80-90% on average). :)
Incredible job, this report inspires me a lot, thanks @naed90. For the inefficient sampling methods, I agree with @zhuohan123: we should implement it in a robust and standard way. FYI, NVIDIA will release TRT-LLM in the near future, which is a combination of TRT and FasterTransformer. For confidentiality reasons, I can't tell the details about the release, but TRT-LLM does implement a sampling kernel. If by then vLLM hasn't finished this part of the work, I think we can reference TRT-LLM and have a try.
BTW: I think this may also be a good direction to accelerate inference: Inference with Reference: Lossless Acceleration of Large Language Models
May I ask in which cases we need to support "many different sampling methods being applied to the requests in the same batch, which includes beam search"? @zhuohan123 @naed90
@naed90 Thanks for your great work. How did you profile vLLM using Nsight Systems? I am also trying to optimize vLLM. Thanks a lot.
@naed90 -- I'd love to understand what the unit of `THREAD_GROUP_SIZE` in the CUDA kernel is. Is it [...], or is it something else?
Hi, I'm the maintainer of LiteLLM; we let you maximize throughput by load balancing between multiple vLLM endpoints. Thought it would be useful for people on this thread -- I'd love feedback if not.
Here's the quick start for the LiteLLM load balancer (works with 100+ LLMs), doc: https://docs.litellm.ai/docs/simple_proxy#model-alias
```yaml
model_list:
  - model_name: openhermes
    litellm_params:
      model: openhermes
      temperature: 0.6
      max_tokens: 400
      custom_llm_provider: "openai"
      api_base: http://192.168.1.23:8000/v1
  - model_name: openhermes
    litellm_params:
      model: openhermes
      custom_llm_provider: "openai"
      api_base: http://192.168.1.23:8001/v1
  - model_name: openhermes
    litellm_params:
      model: openhermes
      custom_llm_provider: "openai"
      frequency_penalty: 0.6
      api_base: http://192.168.1.23:8010/v1
```

```shell
litellm --config /path/to/config.yaml
```

```shell
curl --location 'http://0.0.0.0:8000/chat/completions' \
  --header 'Content-Type: application/json' \
  --data '{
    "model": "openhermes",
    "messages": [
      {
        "role": "user",
        "content": "what llm are you"
      }
    ]
  }'
```
Hi @naed90, are there any PRs in the works for what you've written about in this issue? I see that your original PR #420 is already merged and wondered if this issue can be closed?
Hi @naed90, do you have any plan on the development of the kv cache compression? :)
+34% higher throughput?
TLDR: Seeing vLLM has been really fascinating! @oleitersdorf and I investigated whether we could further accelerate vLLM by profiling its performance with GPU counters. Currently, we believe we have achieved a speed-up of 1.34x for the benchmark reported on the vLLM website. As the vLLM site claims "24x higher throughput compared to HF and up to 3.5x higher throughput than TGI", and the techniques we show below give a further 1.34x, vLLM has the potential to reach 29.5x higher throughput compared to the baseline HF and 4.7x over TGI.
Many thanks to the authors for developing this really exciting work -- we had a great time reading your code! We are sure that you probably already thought of the improvements we show below (and maybe just didn't get to them), and would love to hear your thoughts.
Below we write out the optimizations we found, and list several open directions which could hopefully speed up even further. The goal of this issue is to encourage discussion and brainstorm potential improvements -- some parts are still a POC and require more work to make reach production-ready levels. For the part which is already production-ready, we opened this PR.
This issue has 3 sections:

1. Analyzing the main kernel (`single_query_cached_kv_attention`)
2. Overall program analysis
3. Further potential directions & ideas which did not pan out

Benchmark
We test on the benchmark of using LLaMA13B to complete 1000 randomly sampled prompts from ShareGPT. For each sequence, we create just one completion (matching the benchmark on the project website). To run the benchmark, begin by cloning vLLM, downloading the dataset from the project website, and running the following command.
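(Presumably this is the same throughput benchmark command used elsewhere in this thread; the LLaMA-13B weights path below is a placeholder.)

```shell
python3 benchmarks/benchmark_throughput.py \
    --backend vllm \
    --dataset ./ShareGPT_V3_unfiltered_cleaned_split.json \
    --model <path-to-llama-13b-weights> \
    --num-prompts=1000
```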
We begin by running the above on a clean clone of vLLM on an A100 (80GB), to receive the following output.
This rate of 4.02 sequences completed per second translates to 241.2 seq/min. On the project website, a throughput of 154.2 seq/min is reported for running the same model, yet on an A100 (40GB). For this issue, we are using an A100 (80GB), and so we set the reference point at 4.02 seq/sec. By the end of this issue, we get 5.41 seq/sec, achieving an improvement of 1.34x.
Analyzing `single_query_cached_kv_attention`

The main kernel in vLLM is `single_query_cached_kv_attention`, which is used to compute the forward pass of an attention layer using the KV cache designed in vLLM. We begin by profiling this kernel using NVIDIA Nsight Compute to check for potential improvements.

A preliminary look through NVIDIA Nsight Compute reveals several points to tackle. As seen above, the kernel underutilizes the SM resources both in terms of compute and memory -- it uses only roughly 15% of the compute and 50% of the memory bandwidth.
As seen here, each SM has no warps ready to schedule 5 out of 6 times. Thus, we begin by trying to identify the culprit for why the warps are stalling.
The kernel works roughly as follows. Each block is responsible for computing the entire attention mechanism for the last token of one specific sequence and one specific head in that token. Each block is 128 threads by default (4 warps). Each thread first loads (its part of) the query head into registers: `Q_vec q_vecs[NUM_VECS_PER_THREAD];`. The threads of the block are split into 'groups' such that each group loads the entire query. On our configuration (default configuration + running LLaMA13B on an A100, 80GB), each thread group has 2 threads. That means that every 2 threads in the warp read the entire head of the query into their own registers/local memory -- i.e., each thread holds half of the query head.

To find which stage is holding the warps back, we observe the assembly analysis in Nsight Compute. The warps wait a lot of time on the commands in this screenshot. As we can see, there is a global load happening, and roughly 4% of the time stalls happen there (a value is loaded from global memory into register R78, and warps halt before executing the highlighted instruction because they must wait for the load into R78 to finish). Notably, further below, this code repeats 14 times in total (due to loop unrolling), which causes most of the stalls in this kernel.
These commands are part of the first two steps above, where the threads load the query head and key heads (note: compiling with source code so that Nsight Compute will show the lines in the source the warps are stuck on could help, but it also can significantly change the assembly outputted; therefore, we work directly with the assembly instead -- if someone has a better solution, we would love to hear 🤩). Specifically, these commands are the load of the key heads. As not much can be done about the loading of the key heads, we focus on the query heads which are also loaded from global memory.
The query heads are read multiple times from global memory -- specifically, in our case (default configs, LLaMA13B, A100, 80GB), every byte of the query is read by 64 different threads. Therefore, we begin by optimizing this such that each byte in the query head is read by exactly 1 thread and then stored in shared memory for the other threads in the block to access. We replace the per-thread-group query loads with a cooperative load into shared memory (see this PR).
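As a rough, hypothetical sketch of the idea (simplified, not the actual PR #420 code): each element of the query head is read from global memory by exactly one thread and placed in shared memory, where the rest of the block can reuse it.

```cuda
// Hypothetical, simplified sketch (not the PR #420 code). One block handles one
// (sequence, head) pair; `q` points to that head's query vector in global memory.
template <int HEAD_SIZE, int NUM_THREADS>
__global__ void attention_kernel_sketch(const float* __restrict__ q
                                        /*, keys, values, block tables, ... */) {
  __shared__ float q_shared[HEAD_SIZE];

  // Cooperative load: every element of the query head is read from global memory
  // exactly once per block, instead of once per thread group.
  for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
    q_shared[i] = q[i];
  }
  __syncthreads();

  // From here on, threads read the query from q_shared while streaming the keys and
  // values of the sequence from the KV cache (dot products, softmax, weighted sum).
}
```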
Running the benchmark gives:
As the reference point above is 4.02 seq/sec, the result of 4.43 seq/sec that we get here is a 1.10x improvement.
At this point, we rerun the nsight compute analysis above. As we can see, the kernel is now at a rather high memory bandwidth utilization (86%). We tried several other improvements (see this section) to squeeze a bit more performance out of this kernel, yet they did not improve the overall runtime of the benchmark. Therefore, as the memory bandwidth utilization is rather high and it appears that the kernel is loading the minimal amount of data it needs to from global memory (it has to load the keys and values...), then we decided to stop looking at the kernel itself and began looking elsewhere.
Overall Program Analysis
Observe the following report generated by using NVIDIA Nsight Systems to profile the entire program execution.
As we can see, roughly half the time the program does not use the GPU at all (observe that DRAM Bandwidth, SM Warp Occupancy, etc., are practically zero half the time). This is time spent on the CPU, running the Python code which surrounds the model. We investigate and find that the culprit is the sampling of the generated tokens. Observe the forward code of the class `LlamaForCausalLM`. It turns out that half the program time is spent in the above call to `self.model` and half in the call to `self.sampler` (note: this is not possible to see by timing the Python alone, as the kernels are run on the GPU asynchronously and the CPU waits for them later on).

Specifically, the sampler performs the following for each sequence being completed (link).
That is, for each sequence, `probs` holds the generated probabilities for the next token. The above code focuses on a specific sequence and samples just for that sequence. We replace this by feeding the entire matrix (num_sequences x token_space) into `torch.multinomial`, performing the sampling for all sequences at once. As the code change is rather long, we do not write it out here -- please refer to the following commit to see the change (note that the code is currently a POC handling just the current benchmark -- sampling 1 token per sequence, no beam search or other techniques -- and is not production-grade).
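As a hypothetical illustration of the idea (not the commit itself):

```python
import torch

# Sketch of batched sampling: instead of calling torch.multinomial once per sequence,
# sample every sequence's next token in a single call.
num_seqs, vocab_size = 8, 32000
device = "cuda" if torch.cuda.is_available() else "cpu"

logits = torch.randn(num_seqs, vocab_size, device=device)
probs = torch.softmax(logits, dim=-1)                               # (num_seqs, vocab_size)

# One call samples one token per row of the probability matrix.
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)   # (num_seqs,)

# Gather the log-probabilities of the chosen tokens in one shot as well, so they can
# later be moved to the CPU with a single transfer instead of one read per sequence.
logprobs = torch.log_softmax(logits, dim=-1)
chosen_logprobs = logprobs.gather(1, next_tokens.unsqueeze(-1)).squeeze(-1)
```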
Rerunning the benchmark gives the following.
As the reference point above is 4.02 seq/sec, the result of 5.18 seq/sec that we get here is a 1.28x improvement so far.
We rerun nsight systems and observe the following.
Indeed, the time between GPU calls drastically shrunk. We zoom in to see what remains there. It seems that there are many small 4-byte reads from the GPU to the CPU. The culprit is the following line, where the logprobs of the chosen tokens are read from the GPU to the CPU one by one.
These many small reads have a huge overhead and incur high sync costs. Fixing this by coalescing the reads requires some maneuvering in the code (it turns out that there is another small read in another place). See this commit for a POC. We rerun the benchmark and get the following.
As the reference point above is 4.02 seq/sec, the result of 5.41 seq/sec that we get here is a 1.34x improvement.
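For illustration, the transfer pattern being fixed looks roughly like this (hypothetical sketch, not the actual commit):

```python
import torch

# `chosen_logprobs` stands in for the per-sequence logprobs living on the GPU after sampling.
device = "cuda" if torch.cuda.is_available() else "cpu"
chosen_logprobs = torch.randn(1000, device=device)

# Before: one tiny 4-byte read (and an implicit sync) per sequence.
slow = [chosen_logprobs[i].item() for i in range(chosen_logprobs.numel())]

# After: a single coalesced GPU->CPU copy for the whole batch.
fast = chosen_logprobs.tolist()
```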
Further potential directions & Ideas which did not pan out
Potential idea: Cache Utilization
As can be seen above, the main kernel of the program has very low L1/L2 cache utilization. Taking another look at the kernel, it seems that 4x more data is being read from global memory than is needed -- each read to fetch a key-segment from the KV cache is translated into a read of 64 bytes, yet the thread only uses 16 of those bytes. As the keys in the KV cache have shape `[num_blocks, num_heads, head_size/x, block_size, x]`, these reads are very spaced apart. Similar issues happen with value fetching.

Could it be that reshaping this cache such that each thread uses all 64 bytes that it reads at once would save 4x on memory bandwidth?

We are not 100% certain that only 16 of each 64-byte read is used, as the assembly seems to point to all 64 bytes being used, while the source code seems to imply only 16. Further investigation is needed here (and includes rewriting some of the other kernels/python code to reshape the cache). Therefore, we would appreciate the authors' input here before we try implementing this change (we assume there is a reason this shape was originally chosen) -- i.e., are we missing something :)?
Overall, it seems like it potentially would be worthwhile to investigate the memory loads of this kernel. Observe the following two comments from NVIDIA Nsight Compute:
Potential idea: Parallel Tokenization
There is potential for a further improvement of roughly 10% by parallelizing tokenization after sampling. Specifically, this line gets called sequentially for every sequence to convert each sampled token into text. This takes roughly 10% of the execution time -- time during which the GPU sits completely idle.
Failed idea: Batch reading from the block tables
At the start of the for-loop fetching the keys, the physical block number is read from the block table (global memory). This line stalls many threads. It turns out that all the threads in a warp read the same position in the block table (which is ok since, iirc, only 1 read is sent to memory and its result is broadcast to the threads automatically). To try to reduce the stall, we can have each thread read a different value, and then only once every 32 for-loop iterations would we go out for a global memory read.
We implemented this, yet it had no effect on the runtime. We believe it's due both to the fact that the kernel is memory-bound anyhow, and to the stall seemingly just moving to the key reads (which come right after).
Failed idea: atomicAdd for the last aggregation
We tried replacing the final aggregation in the kernel with atomicAdd between the 4 warps in the block. This degraded the results we observed.
Final Thoughts
vLLM is truly a thought-provoking and intriguing concept! We very much enjoyed delving into this code and are very eager to see how far this can be optimized!! Who knows if it's possible to go even much faster :)
Looking forward to hearing your thoughts!