NVIDIA / TensorRT-LLM

TensorRT-LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and build TensorRT engines that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs. TensorRT-LLM also contains components to create Python and C++ runtimes that execute those TensorRT engines.
https://nvidia.github.io/TensorRT-LLM
Apache License 2.0

Why is there no Alltoall function in MoE implementation? #989

Open YJHMITWEB opened 8 months ago

YJHMITWEB commented 8 months ago

Hi, I am running and profiling the code of the Mixtral implementation; however, neither in the code nor in the profiling did I find any Alltoall operations.

I built the TRT engine using the following config:

python ../llama/build.py --model_dir ./Mixtral-8x7B-v0.1 \
                --use_inflight_batching \
                --enable_context_fmha \
                --use_gemm_plugin \
                --world_size 4 \
                --tp_size 4 \
                --output_dir ./trt_engines/mixtral/TP \
                --moe_tp_mode 1 \
                --max_output_len 2048

I tried both --moe_tp_mode 1 and --moe_tp_mode 2, but they seem to end up with the same tensor parallelism, with no expert parallelism enabled. Also, in the Nsight profiling there are only Allreduce and Allgather calls, which seems insufficient for expert parallelism.

djns99 commented 7 months ago

Hi @YJHMITWEB, thanks for reaching out. You are correct in your observation that TRT-LLM only uses an allreduce (see here).

The allreduce step is a convenience so that the data flow is the same as TP. Instead of broadcasting each token to all the other nodes and then doing the scale/bias steps (which would use all-to-all), we do the rescaling locally on each node and then allreduce the results, using zero tensors for uninitialised tokens.
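For intuition, here is a minimal NumPy sketch (not TRT-LLM code) of that pattern: each rank applies only the experts it owns, leaves zeros for tokens routed elsewhere, and a sum-allreduce (simulated below as a plain sum over per-rank buffers) reproduces the same result an all-to-all dispatch/combine would. The expert weights, top-1 routing, and sizes are hypothetical and chosen only for illustration.

# Minimal sketch of allreduce-based expert parallelism (illustrative, not TRT-LLM code).
import numpy as np

num_ranks, num_tokens, hidden = 4, 8, 16
experts_per_rank = 2
num_experts = num_ranks * experts_per_rank

rng = np.random.default_rng(0)
tokens = rng.standard_normal((num_tokens, hidden))
# Hypothetical top-1 routing: each token is assigned to exactly one expert.
routing = rng.integers(0, num_experts, size=num_tokens)
# Hypothetical expert weights: one matrix per expert, for simplicity.
expert_w = rng.standard_normal((num_experts, hidden, hidden))

partial_outputs = []
for rank in range(num_ranks):
    local_experts = range(rank * experts_per_rank, (rank + 1) * experts_per_rank)
    out = np.zeros_like(tokens)  # zero tensor for tokens this rank does not own
    for e in local_experts:
        mask = routing == e
        out[mask] = tokens[mask] @ expert_w[e]  # apply the local expert in place
    partial_outputs.append(out)

# The allreduce (simulated here as a sum over ranks) combines the partial results.
combined = np.sum(partial_outputs, axis=0)

# Reference: apply each token's expert directly, as an all-to-all dispatch/combine would.
reference = np.stack([tokens[i] @ expert_w[routing[i]] for i in range(num_tokens)])
assert np.allclose(combined, reference)

Because the unowned rows are zero, the sum-allreduce yields exactly the same output as routing every token to its expert's rank and gathering the results back, which is why the data flow can mirror plain TP.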

There may be cases where the all-to-all pattern is better and we will continue actively investigating this option.

In general, though, we recommend Tensor Parallelism because of the load-balancing issues that are inherent in Expert Parallelism.