NVIDIA / TensorRT-LLM

TensorRT-LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and build TensorRT engines that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs. TensorRT-LLM also contains components to create Python and C++ runtimes that execute those TensorRT engines.
https://nvidia.github.io/TensorRT-LLM
Apache License 2.0

Question regarding the unusual behavior of GPT-J 6B's XQA kernel in the MLPerf 4.0 version of TensorRT-LLM #1874

Open bongwonjang opened 2 days ago

bongwonjang commented 2 days ago

Hello TensorRT-LLM experts! I have a question about the unusual behavior of the XQA kernel in NVIDIA's official MLPerf 4.0 version of TensorRT-LLM.

First, here is the environment in which I ran my tests.

While testing GPT-J 6B (FP8-quantized, beam_width = 4) on a customized cnn_daily_mails dataset (with inputs truncated to lengths between 128 and 512 tokens), I observed that computation became somewhat faster at specific sequence lengths.

Specifically, when the input length was a multiple of 256, the generation phase was slightly faster than at other input lengths.

Using Nsight Compute, I observed the following differences in memory transfer size. The image below shows measurements taken during the first generation step after the context phase, for input sequence lengths of 255 and 256.

First, both images show that approximately 8 MB of data (i.e., the KV cache: 2 × 1 byte × 256 × 4096 × 4, where the last factor is batch_size) is read from device memory into the L2 cache.
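The 8 MB figure follows directly from the formula above. A minimal Python sketch of that arithmetic, assuming (as the formula does) a hidden size of 4096, an FP8 KV cache at 1 byte per element, batch_size = 4, a single attention layer per profiled kernel launch, and input-token K/V stored once per sequence rather than once per beam. `kv_cache_bytes` is a hypothetical helper name:

```python
def kv_cache_bytes(num_tokens: int, hidden_size: int = 4096,
                   bytes_per_elem: int = 1, batch_size: int = 4) -> int:
    """Bytes of K and V for `num_tokens` tokens across the batch, one layer.

    Assumes input-token K/V is stored once per sequence (shared by all beams).
    """
    k_and_v = 2  # one tensor for keys, one for values
    return k_and_v * bytes_per_elem * num_tokens * hidden_size * batch_size


size = kv_cache_bytes(256)
print(size, "bytes =", size / 2**20, "MiB")  # 8388608 bytes = 8.0 MiB
```

Note that 8 MiB is about 8.39 × 10^6 bytes, so "approximately 8 MB" holds in either unit.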

However, the truly strange part is the amount of data read from the L2 cache into shared memory.

The former (length 255) reads 50 MB from the L2 cache, while the latter (length 256) reads only 12 MB, even though the only difference between them is that the input sequence is one token longer. Even stranger, the same phenomenon occurred for input sequence lengths of (511, 512) and (767, 768).

To understand these results, I watched NVIDIA's Nsight memory analysis video, but I couldn't determine the cause of XQA's unusual behavior.

I know that the XQA source code will not be made public. However, it would be very helpful to know whether this issue has been resolved in the current version of TensorRT-LLM or whether it remains an area for improvement in XQA. I hope this information will help others using TensorRT-LLM! :bow:

lowsfer commented 1 day ago

What you observed is a result of the kernel design. XQA is efficient for beam search because, in beam search, the KV cache entries generated by the input tokens are always shared among all beam search candidates, and XQA is optimized for this case. XQA uses one thread block to handle all beam search candidate sequences. As it iterates through the KV cache, it loads one tile of tokens (256 in your case) per iteration. If all tokens in the tile come from the input, XQA knows that it only needs to load them once, because they are uniform for all beam search candidates. This optimization works at the granularity of a tile (256 tokens in your case). If a tile contains at least one generated token, XQA falls back to normal mode for that tile, i.e., it loads the tile once for each beam search candidate. So if a tile contains both input and output tokens, the input tokens in that tile are loaded multiple times, which may cause extra L2 traffic.
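To make the tile rule concrete, here is a hypothetical Python model of the loading behavior described above. It is not the XQA source (which is closed); it only counts K/V bytes moved from L2 into shared memory for one layer in one generation step. Assumptions: tile = 256 tokens, beam_width = 4, batch_size = 4, hidden size 4096, FP8 K/V, a partial final tile moves only the tokens it holds, and in the first generation step the sequence holds the input tokens plus one generated token. `l2_to_smem_kv_bytes` is an illustrative name.

```python
def l2_to_smem_kv_bytes(input_len: int, seq_len: int, tile: int = 256,
                        beam_width: int = 4, batch_size: int = 4,
                        hidden_size: int = 4096, bytes_per_elem: int = 1) -> int:
    """Hypothetical count of K/V bytes loaded from L2 into shared memory
    for one layer and one generation step (not the real XQA code)."""
    bytes_per_token = 2 * bytes_per_elem * hidden_size  # K and V of one token
    total = 0
    for start in range(0, seq_len, tile):
        end = min(start + tile, seq_len)  # the last tile may be partial
        if end <= input_len:
            # Every token in this tile is an input token, shared by all beams:
            # load the tile once per sequence.
            loads = 1
        else:
            # At least one generated token: fall back to loading the whole tile
            # (its input tokens included) once per beam search candidate.
            loads = beam_width
        total += (end - start) * bytes_per_token * loads
    return total * batch_size


# First generation step: input tokens plus one generated token (assumption).
print(l2_to_smem_kv_bytes(255, 256))  # mixed tile -> 33554432
print(l2_to_smem_kv_bytes(256, 257))  # input-only tile + 1-token tile -> 8519680
```

Under these assumptions, a 255-token input puts the first generated token into the first tile, so the whole tile is loaded four times, while a 256-token input keeps the first tile input-only. The model reproduces the roughly 4x gap between the two cases, though not the absolute values Nsight Compute reports.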

Does the above explanation make sense to you?

We are not yet checking output tokens to see whether some of them are also uniform for all beam search candidates. This is a potential optimization for generation-heavy tasks. We are aware of it, but it is not yet prioritized.
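The optimization mentioned here, detecting generated tokens whose K/V entries are still uniform across beams, could be sketched as follows. This is only an illustration under stated assumptions: `src_beam[b][t]` is a hypothetical table giving which beam slot's K/V entry candidate `b` reads at position `t` (similar in spirit to a cache-indirection table), and a tile qualifies for a single load only if every beam reads the same slot for every token in it. `shared_tiles` is an illustrative name, not a TensorRT-LLM API.

```python
from typing import List


def shared_tiles(src_beam: List[List[int]], tile: int) -> List[bool]:
    """Return, per tile, whether all beams read the same K/V slot for every
    token in it (so the tile could be loaded once instead of once per beam).

    src_beam[b][t] is a hypothetical table: the beam slot whose K/V entry
    beam search candidate b reads at token position t.
    """
    num_beams = len(src_beam)
    seq_len = len(src_beam[0])
    result = []
    for start in range(0, seq_len, tile):
        end = min(start + tile, seq_len)
        uniform = all(
            len({src_beam[b][t] for b in range(num_beams)}) == 1
            for t in range(start, end)
        )
        result.append(uniform)
    return result


# 4 beams, 4 input tokens (positions 0-3), 6 generated tokens (positions 4-9).
# Positions 4-5 come from a common parent; the beams diverge from position 6.
example = [
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
    [0, 0, 0, 0, 0, 0, 0, 2, 2, 2],
    [0, 0, 0, 0, 0, 0, 1, 3, 3, 3],
]
print(shared_tiles(example, tile=2))  # [True, True, True, False, False]
```

In this toy example the tile holding generated positions 4 and 5 is still uniform and could be loaded once, while later tiles diverge and need one load per beam. A tile-granular check like this keeps the same all-or-nothing structure as the existing input-token rule.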

bongwonjang commented 3 hours ago

Thank you @lowsfer for the detailed answer to exactly what I was curious about! After reading your comment, I reviewed the beam search method of PagedAttention and its code with the XQA optimization in mind. Indeed, the approach you described seems to be the best one, although of course I don't know the actual internal code!

:thinking: However, I still have some unresolved questions, and I would like to propose a candidate optimization for the current XQA operation (from the perspective of memory access rather than computation), although it is not high priority.

The following figure illustrates my understanding of the XQA operation. Let's assume that one tile reads 256 tokens.

By examining the upper and lower cases in the figure, we can calculate how much memory (specifically, KV cache) is transferred. As defined above, beam width = 4 and batch size = 4.

:thinking: Here is where I ran into an unresolved issue. When I profiled with Nsight Compute, the amounts of memory actually read were 50.53 MB and 12.92 MB, respectively.

Interestingly, these values are about 1.5 times the theoretically calculated values (indicated by the SPECIAL NUMBER in the figure). Moreover, this 1.5x ratio remained nearly constant across various input sequence lengths!
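A quick check of the roughly 1.5x ratio can be done in a few lines. The theoretical values here are a reconstruction, not necessarily the SPECIAL NUMBER in the figure. They assume the per-tile rule described by @lowsfer (tile = 256 tokens, one generated token in the first generation step, a partial last tile moving only its own tokens), count only the K/V bytes of one layer, and treat the reported MB as 10^6 bytes.

```python
BYTES_PER_TOKEN = 2 * 1 * 4096  # K and V per token: FP8, hidden size 4096
BATCH, BEAMS, TILE = 4, 4, 256

# Input length 255: the single tile mixes 255 input tokens with 1 generated
# token, so the whole tile is loaded once per beam.
theory_255 = TILE * BYTES_PER_TOKEN * BEAMS * BATCH
# Input length 256: the first tile is input-only (loaded once); the second
# tile holds just the 1 generated token (loaded once per beam).
theory_256 = (TILE * BYTES_PER_TOKEN + 1 * BYTES_PER_TOKEN * BEAMS) * BATCH

measured = {255: 50.53e6, 256: 12.92e6}  # bytes, from the Nsight Compute runs above
for length, theory in ((255, theory_255), (256, theory_256)):
    print(f"len={length}: theory={theory / 1e6:.2f} MB, "
          f"measured/theory={measured[length] / theory:.2f}")
```

Under these assumptions it prints theoretical sizes of 33.55 MB and 8.52 MB and ratios of about 1.51 and 1.52, consistent with the near-constant 1.5x overhead described above.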

In summary, the total amount of data currently transferred from the L2 cache to shared memory is 50% more than expected (although I considered only the KV cache). I think this points to a potential optimization opportunity.

Given my limited CUDA experience and without access to the actual code, it is difficult for me to determine whether this issue arises from bank conflicts or from uncoalesced memory access.
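As background for the uncoalesced-access hypothesis: NVIDIA GPUs move global-memory data through L2 in 32-byte sectors, so a warp whose addresses are scattered can pull in more bytes than it uses. Below is a simplified Python model of sector counting. It is a toy under assumed access patterns, not a profile of XQA, and shared-memory bank conflicts are not modeled. `sectors_touched` is an illustrative name.

```python
SECTOR = 32  # bytes per L2 sector
WARP = 32    # threads per warp


def sectors_touched(addresses, bytes_per_thread):
    """Count distinct 32-byte sectors covered when each thread loads
    `bytes_per_thread` bytes starting at its own address."""
    sectors = set()
    for addr in addresses:
        first = addr // SECTOR
        last = (addr + bytes_per_thread - 1) // SECTOR
        sectors.update(range(first, last + 1))
    return len(sectors)


coalesced = [16 * lane for lane in range(WARP)]  # contiguous 16-byte chunks
strided = [48 * lane for lane in range(WARP)]    # 16-byte chunks, 48-byte stride

for name, addrs in (("coalesced", coalesced), ("strided", strided)):
    n = sectors_touched(addrs, 16)
    print(f"{name}: {n} sectors = {n * SECTOR} bytes for {WARP * 16} useful bytes")
```

Here the contiguous pattern touches 16 sectors (512 bytes) for 512 useful bytes, while the 48-byte stride touches 32 sectors (1,024 bytes), a 2x overhead. Confirming either hypothesis for the actual kernel would still require Nsight Compute's sector-level metrics.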

If you could clarify

  1. whether I have misunderstood, :sweat_smile:
  2. whether this is an optimization issue that needs to be addressed,
  3. whether the additional data (50% of the KV cache size) is genuinely needed by the XQA kernel,
  4. or whether this is something only TensorRT-LLM core developers have access to,

it would greatly enhance our trust in the operation of XQA and TensorRT-LLM!

If it is difficult to answer, I would appreciate it if you could let me know that you cannot provide a response. I would be delighted if this analysis could contribute, even slightly, to the improvement of XQA operation and TensorRT-LLM.