vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Performance]: Long AllReduce wait time on 1 device with tensor parallelism #5792

Open wenscarl opened 3 months ago

wenscarl commented 3 months ago

Proposal to improve performance

Propose synchronizing the broadcast of tensor_dict at the beginning of each decoding step, or blocking each process after the broadcast.
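A minimal toy sketch of the second option ("block the process after broadcast"). It assumes threads stand in for tensor-parallel ranks and that `time.sleep` stands in for the broadcast finishing at different times on each rank; `simulate_block_after_broadcast` is a hypothetical name, not vLLM code. In a real multi-process setup the blocking step would be a collective such as `torch.distributed.barrier()` rather than `threading.Barrier`.

```python
import threading
import time


def simulate_block_after_broadcast(recv_delays):
    """Toy model: each thread plays one TP rank.

    recv_delays[i] is a made-up delay (seconds) before rank i finishes
    receiving the broadcast tensor_dict. Every rank then blocks on a
    barrier, so no rank starts "launching kernels" until all have arrived.
    """
    n = len(recv_delays)
    barrier = threading.Barrier(n)
    arrived = [0.0] * n
    launched = [0.0] * n

    def rank(i):
        time.sleep(recv_delays[i])      # broadcast completes at a different time per rank
        arrived[i] = time.monotonic()   # this rank now has its inputs
        barrier.wait()                  # proposed: wait until every rank has its inputs
        launched[i] = time.monotonic()  # kernel launches would start here

    threads = [threading.Thread(target=rank, args=(i,)) for i in range(n)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return arrived, launched


arrived, launched = simulate_block_after_broadcast([0.0, 0.02, 0.05, 0.01])
print(f"arrival spread: {max(arrived) - min(arrived):.3f}s")
print(f"launch spread:  {max(launched) - min(launched):.3f}s")
```

Note that the barrier does not remove the skew; it moves the waiting to a point before any kernel is launched, so the launches that follow start roughly together instead of drifting apart.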

Report of performance regression

In the decoding stage, each tensor-parallel matrix multiplication is followed by an all-reduce, which implicitly synchronizes the processes. However, the asynchronous broadcast of tensor dictionaries (code available here) at the start of each decoding step causes CUDA kernels to launch at quite different times across processes. This leads to the scenario shown in the two attached screenshots (image (12) and image (13), not preserved here). @youkaichao
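A back-of-the-envelope model of why the launch skew shows up as all-reduce wait time. It assumes, for illustration only, that every rank runs the same amount of matmul work before the all-reduce and differs only in when it starts; `allreduce_wait_times` is a hypothetical helper and the offsets are made-up numbers, not measurements from the profile above.

```python
def allreduce_wait_times(launch_offsets, compute_time):
    """Time each rank spends waiting inside the all-reduce.

    The all-reduce cannot complete until the last rank arrives, so every
    earlier rank idles for the gap between its own arrival and the last one.
    """
    arrivals = [start + compute_time for start in launch_offsets]
    last = max(arrivals)
    return [last - a for a in arrivals]


# Rank 0 starts immediately; the other ranks start later because they are
# still finishing the broadcast of tensor_dict (offsets in ms, invented).
print(allreduce_wait_times([0.0, 3.0, 2.0, 3.5], compute_time=1.0))
# -> [3.5, 0.5, 1.5, 0.0]
```

In this model the wait depends only on the spread of start times, not on `compute_time`, which is consistent with one device showing a long all-reduce while the others do not.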

Misc discussion on performance

No response

Your current environment (if you think it is necessary)

CUDA_VISIBLE_DEVICES=0,1,2,3 nsys profile -t cuda,nvtx python benchmarks/benchmark_throughput.py --model=meta-llama/Meta-Llama-3-70B-Instruct --quantization=fp8 --dataset=/workspace/sw3/vllm/ShareGPT_V3_unfiltered_cleaned_split.json --output-len=64 --num-prompts=50 --enforce-eager -tp=4
youkaichao commented 2 months ago

Thanks for the report! We do plan to remove this broadcast call; you can track the progress at https://github.com/vllm-project/vllm/issues/6241. Once that issue is resolved, the driver process will send a lightweight Python object to all processes, and each process will prepare its inputs itself, so we won't need to broadcast tensors.
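A rough sketch of the direction described above, assuming a decode-only step in which the driver only needs to tell each worker which token every sequence just produced and how long each sequence is. `StepMetadata` and `prepare_inputs` are hypothetical names for illustration; the actual design is tracked in issue #6241 and may differ. `pickle` simply stands in for whatever serialization the transport uses, and real workers would build GPU tensors and attention metadata rather than Python lists.

```python
import pickle
from dataclasses import dataclass


@dataclass(frozen=True)
class StepMetadata:
    """Hypothetical per-step message from the driver (not a real vLLM type)."""
    seq_ids: tuple[int, ...]
    last_token_ids: tuple[int, ...]  # token sampled for each sequence last step
    seq_lens: tuple[int, ...]        # tokens in each sequence, including that one


def prepare_inputs(meta: StepMetadata) -> dict[str, list[int]]:
    """Build decode-step inputs locally on a worker from the metadata alone."""
    return {
        "input_ids": list(meta.last_token_ids),
        # In a decode step the newest token sits at 0-based index seq_len - 1.
        "positions": [n - 1 for n in meta.seq_lens],
    }


# Driver side: serialize one small object instead of broadcasting tensors.
meta = StepMetadata(seq_ids=(7, 9), last_token_ids=(1012, 345), seq_lens=(18, 5))
payload = pickle.dumps(meta)

# Worker side (simulated for 4 TP ranks): each rank rebuilds identical inputs.
worker_inputs = [prepare_inputs(pickle.loads(payload)) for _ in range(4)]
print(len(payload), "bytes;", worker_inputs[0])
```

The trade-off is that every rank repeats the same small amount of CPU work, in exchange for removing a per-step tensor broadcast from the critical path.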

eileenzhujuan commented 2 months ago

Hi, I'm curious about this proposal, as I've run into a similar problem. With tp_size=4, one of the ranks (not tp_rank=0) shows much slower kernel launches. As a result, every attention layer, which ends with an all_reduce, becomes slower. So do you mean that making the broadcast at the beginning of each decode step synchronize immediately would relieve the problem?