Open 0xymoro opened 8 months ago
Hi @0xymoro, when inflight batching is enabled (that is, when you enable the GPT attention plugin, fused MHA, paged KV cache, and remove input padding, all of which are enabled by default on the latest main branch), `max_batch_size` no longer affects the activation memory required by the TensorRT engines. That means you can set `max_batch_size` as large as you want without getting OOM; at runtime, the actual batch size will be limited by the available GPU memory.
What really affects the activation memory required by the TensorRT engines is `max_num_tokens`, so you need to be careful not to set that argument too large. The settings of `max_batch_size` and `max_num_tokens` should be decoupled because padding in the inputs is removed.
> For a 70b model, quantized to fp8, running on 2xH100s TP2.
If we go back to your case, the inflight batching knobs are recommended to be enabled. When estimating `max_num_tokens`, the `max_batch_size` involved in the estimation should be the actual maximum batch size that your GPU memory can hold, so it cannot be too large; from that you can derive an estimate for `max_num_tokens`. If you really don't know what `max_num_tokens` should be, I would suggest starting from 8192 (an empirical value) and tuning it upward to see which value gives you the best performance.
I know that setting `max_num_tokens` is a little bit tricky, and we're trying to make it easier. Hope that helps, and feel free to let me know if that's still not clear to you, thank you.
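For concreteness, here is a rough Python sketch of that tuning loop. The `measure_throughput` callback is hypothetical: you would implement it yourself by rebuilding the engine with the candidate value and running your own benchmark (for example the C++ gptManagerBenchmark tool); none of this is a TensorRT-LLM API.

```python
from typing import Callable

def tune_max_num_tokens(measure_throughput: Callable[[int], float],
                        start: int = 8192,
                        factor: int = 2,
                        max_steps: int = 4) -> int:
    """Start from the empirical 8192 and keep increasing max_num_tokens
    while the measured throughput keeps improving."""
    best_value = start
    best_tput = measure_throughput(start)
    candidate = start
    for _ in range(max_steps):
        candidate *= factor
        tput = measure_throughput(candidate)
        if tput <= best_tput:   # no further gain (or a regression): stop here
            break
        best_value, best_tput = candidate, tput
    return best_value
```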
Got it. I do think 8192 seems conservative. I see in your example performance benchmarks that a similar setup can run batch sizes of at least 64, and if 8192 is the max num tokens, that means at most ~4 requests with 2048 input tokens can run in parallel, which does seem a bit on the low side.
I think it would help the documentation to include a few example numbers and setups that you found to be optimal. Even if those setups don't apply to many people, it would help to see how model size, quantization, and input and output sequence lengths affect this number. I see some similarities with TGI, which auto-infers something like the max num tokens, and the values there seem to be higher, so I think there's likely a lot more room to push TRTLLM higher than 8192, though I may be confounding what the two numbers mean across the platforms.
> if 8192 is max num tokens that means at most ~4 requests with 2048 input can be run in parallel
@0xymoro Note that generation requests will most likely make up most of the requests in a batch, and a generation request contributes only 1 token, which means that an 8192-token budget can contain, for example, 3 context requests of length 2048 plus 2048 generation requests, if the GPU memory is enough.
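To make that arithmetic concrete, a toy sketch (hypothetical numbers, assuming a per-iteration token budget equal to `max_num_tokens` = 8192):

```python
# Toy illustration of how the token budget is shared when input padding is
# removed: each context (prefill) request contributes its full input length,
# while each generation request contributes only 1 token.

MAX_NUM_TOKENS = 8192

context_lengths = [2048, 2048, 2048]        # 3 context requests
context_tokens = sum(context_lengths)       # 6144 tokens of the budget

# Whatever is left over can be filled with generation requests at 1 token each,
# provided the KV cache still fits in GPU memory.
generation_requests = MAX_NUM_TOKENS - context_tokens   # 2048

print(context_tokens, generation_requests)  # 6144 2048
```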
Thanks for the nice suggestions, and yes, I agree that more documentation would help. We're working on the features as well as the documentation to improve the user experience. Thanks for your support.
@kaiyux would it be possible to provide an example scenario/load profile where using a non-default value for `max_num_tokens` would lead to better results than the default?
Hi @0xymoro, do you still have any further issues or questions? If not, we'll close this soon.
The documentation gives details on estimating max num tokens. What's less clear to me is how to estimate the max batch size for the hardware I have, and what typical/recommended max num tokens are for a given hardware, model, and quantization setup. I'd love some more documentation here, but as an example:
For a 70b model, quantized to fp8, running on 2xH100s with TP2: if I have a max input length of 4096, how should I go about setting a max batch size that makes sense? With a max batch size of 64, my max num tokens would be around ~52k (from the formula, assuming 20% of requests are in prefill). With a max batch size of 128, it would be ~104k. It would also be great if the performance benchmarks listed the max num tokens used, so that, given the hardware and setup, people can see how much it can handle and adjust batch size etc. accordingly. Is this the correct intuition?
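As a rough illustration of the estimate I'm referring to (the exact formula in the docs may differ slightly), a small helper that reproduces the ~52k and ~104k numbers above, assuming a fraction of scheduled requests is in the context (prefill) phase and the rest generate 1 token each:

```python
def estimate_max_num_tokens(max_batch_size: int,
                            max_input_len: int,
                            prefill_fraction: float = 0.2) -> int:
    """Hypothetical helper: tokens from prefill requests plus 1 token per
    generation request in a single in-flight batching iteration."""
    prefill_tokens = prefill_fraction * max_batch_size * max_input_len
    decode_tokens = (1.0 - prefill_fraction) * max_batch_size
    return int(prefill_tokens + decode_tokens)

print(estimate_max_num_tokens(64, 4096))    # 52480  (~52k, the 64-batch case)
print(estimate_max_num_tokens(128, 4096))   # 104960 (~104k-105k, the 128-batch case)
```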
@kaiyux @juney-nvidia