ggerganov / llama.cpp

LLM inference in C/C++

Feature Request: Tensor Parallelism support #9086

Open · ClarkChin08 opened this issue 3 weeks ago

ClarkChin08 commented 3 weeks ago


Feature Description

Tensor parallelism is a critical technique for training and running inference on very large language models by splitting the actual computations/tensors across multiple compute devices.

Motivation

In our previous implementation on Xeon CPUs, tensor parallelism (TP) significantly reduced inference latency:

| model | precision | TP size | input_size | next_token_time/ms |
| --- | --- | --- | --- | --- |
| llama2-70b | q4_j | 1 | 32 | 191.91 |
| llama2-70b | q4_j | 2 | 32 | 120.87 |
| llama2-70b | q4_j | 4 | 32 | 86.15 |
| llama2-70b | q4_j | 1 | 1024 | 197.18 |
| llama2-70b | q4_j | 2 | 1024 | 129.25 |
| llama2-70b | q4_j | 4 | 1024 | 91.76 |
| llama2-70b | q4_j | 1 | 2012 | 204.85 |
| llama2-70b | q4_j | 2 | 2012 | 127.31 |
| llama2-70b | q4_j | 4 | 2012 | 100.44 |

Note: TP size = 1 means TP is not used.

Possible Implementation

In our TP implementation, we pre-split the corresponding weights, so the time spent on splitting is a one-time cost and does not affect inference performance. The other major factor impacting performance is the 'all reduce': since each node computes only partial, incomplete results, an 'all reduce' must be performed on the output data, and this operation is relatively time-consuming. Interestingly, with a well-chosen splitting and combining scheme, consecutive primitives can be computed independently on each node without intermediate synchronization, which helps performance a lot. A rational splitting method is therefore extremely important.
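To make the 'all reduce' concrete, here is a minimal single-process sketch (not llama.cpp code; the 2-node loop, layout, and sizes are made up for illustration): each simulated node holds a shard of the weight along the reduction dimension plus the matching slice of the input, computes a partial output, and the all-reduce, simulated here as a plain accumulation, sums the partials into the complete result that a real implementation would exchange over MPI, NCCL, or the backend's own primitives.

```cpp
// Each "node" holds a shard of W along the reduction (input) dimension and the
// matching slice of x; its local matmul yields a partial output. Summing the
// partials across nodes -- the all-reduce -- recovers the full result y = W*x.
#include <cstdio>
#include <vector>

int main() {
    const int rows = 3, cols = 4, world = 2, shard = cols / world;

    // Full weight (row-major, rows x cols) and full input.
    std::vector<float> W = { 1, 2, 3, 4,
                             5, 6, 7, 8,
                             9,10,11,12 };
    std::vector<float> x = { 1, -1, 2, 0.5f };

    // All-reduce buffer: starts at zero, accumulates every node's partial y.
    std::vector<float> y(rows, 0.0f);

    for (int node = 0; node < world; ++node) {
        // Local matmul over this node's shard of the reduction dimension.
        for (int r = 0; r < rows; ++r) {
            float partial = 0.0f;
            for (int c = 0; c < shard; ++c) {
                partial += W[r*cols + node*shard + c] * x[node*shard + c];
            }
            y[r] += partial; // simulated all-reduce (sum across nodes)
        }
    }

    // y now equals the unsplit W*x on every conceptual node: 7.0 17.0 27.0
    for (float v : y) std::printf("%.1f ", v);
    std::printf("\n");
}
```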

Taking the FFN module as an example: if the first matmul's weight is split by column, its product with the input yields two unrelated sub-matrices, one per node. If the second matmul's weight is split by rows, those sub-matrices can be consumed directly, with no 'all reduce' in between. The entire FFN module therefore requires only one 'all reduce'; in other words, with a properly tailored split, even a chain of several matmuls may need just a single 'all reduce'. The element-wise operations between the matmuls can be ignored here, since they do not influence the result.

[figure: FFN split across two nodes]

The attention module is more complex, but as the following figure shows, a rational split likewise lets the entire attention module get by with a single 'all reduce', greatly reducing synchronization time.

[figure: attention split across two nodes]
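The FFN case can be checked numerically. The sketch below is a single-process simulation with made-up dimensions and a ReLU stand-in for the activation, not llama.cpp code: it shards the first weight along the intermediate dimension and the second along its matching reduction dimension, applies the element-wise activation locally, and sums the per-node partial outputs once at the end, reproducing the dense result with a single simulated all-reduce.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

using Vec = std::vector<float>;

// y[r] = sum_c W[r*cols + c] * x[c]; W is row-major with shape rows x cols.
static Vec matvec(const Vec& W, const Vec& x, int rows, int cols) {
    Vec y(rows, 0.0f);
    for (int r = 0; r < rows; ++r)
        for (int c = 0; c < cols; ++c)
            y[r] += W[r*cols + c] * x[c];
    return y;
}

static void relu(Vec& v) { for (float& f : v) f = std::max(0.0f, f); }

int main() {
    const int d_model = 4, d_ff = 6, world = 2, shard = d_ff / world;

    // Arbitrary dense weights and input for the reference computation.
    Vec W1(d_ff*d_model), W2(d_model*d_ff), x = {1.0f, -2.0f, 3.0f, 0.5f};
    for (size_t i = 0; i < W1.size(); ++i) W1[i] = 0.01f*i - 0.2f;
    for (size_t i = 0; i < W2.size(); ++i) W2[i] = 0.03f*i - 0.3f;

    // Reference: the full FFN computed on a single node.
    Vec h = matvec(W1, x, d_ff, d_model);
    relu(h);
    Vec y_ref = matvec(W2, h, d_model, d_ff);

    // Tensor-parallel version: each "node" only ever touches its own shards.
    Vec y_tp(d_model, 0.0f); // all-reduce accumulator
    for (int node = 0; node < world; ++node) {
        // Shard of W1 along the intermediate dimension -> a slice of the hidden vector.
        Vec W1s(W1.begin() + node*shard*d_model, W1.begin() + (node + 1)*shard*d_model);
        Vec hs = matvec(W1s, x, shard, d_model);
        relu(hs); // element-wise op runs locally, no communication needed

        // Matching shard of W2 along its reduction dimension -> a partial, full-size output.
        Vec ys(d_model, 0.0f);
        for (int r = 0; r < d_model; ++r)
            for (int c = 0; c < shard; ++c)
                ys[r] += W2[r*d_ff + node*shard + c] * hs[c];

        // The single all-reduce of the whole FFN, simulated as an accumulation.
        for (int r = 0; r < d_model; ++r) y_tp[r] += ys[r];
    }

    for (int r = 0; r < d_model; ++r) assert(std::fabs(y_tp[r] - y_ref[r]) < 1e-4f);
    std::printf("tensor-parallel FFN matches the dense FFN with one all-reduce\n");
}
```

The attention block presumably follows the same pattern, e.g. sharding by heads so that everything up to the output projection stays local, which would again leave a single 'all reduce' per block.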

Chocobi-1129 commented 3 weeks ago

Not sure if this is related to #4014

ClarkChin08 commented 2 weeks ago

> Not sure if this is related to #4014

To reduce the communication time and improve latency, we should minimize the use of 'all reduce.' My proposal includes two improvements:

  1. Splitting the Weight Tensor Before Inference:

    • We can split the weight tensor and distribute the partial weights across the TP (Tensor Parallel) nodes during the 'llm_load_tensors' phase. This involves adding three specific tensor split types when creating the tensors. The different split methods let us avoid 'all reduce' operations between two matrix multiplications (matmuls), just as the figures above show. Element-wise operations likewise operate only on the partial tensors. [figure: tensor split types]

    • We can create the weight tensors with their split type as shown in the following illustration: [figure: weight tensors with split types]

  2. Inference with Split Weights:

    • After splitting the weights, each node infers only its part of the model. The only change is that each attention block and each MLP block now performs one 'all reduce' operation. These 'all reduce' operations always follow the matmul whose weights are split by column.
    • We can fuse the 'all reduce' with the matmul operation by checking the weight's split type, as illustrated below and sketched in code after this section: [figure: matmul fused with 'all reduce']

By setting each tensor's split type during the weight-loading phase and adding 'all reduce' support to the matmul calculation, we can reduce the number of 'all reduce' operations to just two per layer, while each node's computational workload is only 1/world_size of the original, which significantly improves latency. The relevant pull request will be submitted soon, and any comments are welcome.
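To make the proposal concrete, here is a rough sketch of how the two steps could fit together. Everything below is hypothetical and only mirrors the idea: `SplitType`, `TensorShard`, `load_split_weight`, `mul_mat_tp`, and `all_reduce_sum` are placeholder names, not existing ggml/llama.cpp APIs, and the split types are named by the dimension they shard to sidestep row/column naming that depends on the storage layout. A weight's split type is recorded at load time, and the matmul wrapper consults it to decide whether its output is already a complete slice or a partial sum that must be all-reduced.

```cpp
#include <cstdio>
#include <vector>

enum class SplitType {
    NONE,          // weight replicated on every node
    SPLIT_OUTPUT,  // output dim sharded: local result is a disjoint slice, no all-reduce
    SPLIT_REDUCE,  // reduction dim sharded: local result is a partial sum, needs all-reduce
};

struct TensorShard {
    std::vector<float> data;           // this node's shard, row-major [rows x cols]
    int rows = 0, cols = 0;
    SplitType split = SplitType::NONE;
};

// Placeholder for the real communication primitive (MPI, NCCL, or backend-specific).
static void all_reduce_sum(std::vector<float>& partial) {
    (void)partial; // ... element-wise sum of 'partial' across all TP nodes ...
}

// Step 1 of the proposal: at weight-loading time keep only this node's shard
// and record how the weight was split.
static TensorShard load_split_weight(const std::vector<float>& full, int rows, int cols,
                                     SplitType split, int rank, int world) {
    TensorShard t;
    t.split = split;
    if (split == SplitType::SPLIT_OUTPUT) {        // shard the output dimension
        t.rows = rows / world; t.cols = cols;
        t.data.assign(full.begin() + rank*t.rows*cols,
                      full.begin() + (rank + 1)*t.rows*cols);
    } else if (split == SplitType::SPLIT_REDUCE) { // shard the reduction dimension
        t.rows = rows; t.cols = cols / world;
        t.data.resize(t.rows*t.cols);
        for (int r = 0; r < rows; ++r)
            for (int c = 0; c < t.cols; ++c)
                t.data[r*t.cols + c] = full[r*cols + rank*t.cols + c];
    } else {                                       // replicate
        t.rows = rows; t.cols = cols; t.data = full;
    }
    return t;
}

// Step 2 of the proposal: y = W*x on the local shard. When the weight was split
// along the reduction dimension, x is expected to be the matching local slice and
// the partial result is immediately all-reduced, fusing synchronization into the matmul.
static std::vector<float> mul_mat_tp(const TensorShard& W, const std::vector<float>& x) {
    std::vector<float> y(W.rows, 0.0f);
    for (int r = 0; r < W.rows; ++r)
        for (int c = 0; c < W.cols; ++c)
            y[r] += W.data[r*W.cols + c] * x[c];

    if (W.split == SplitType::SPLIT_REDUCE) {
        all_reduce_sum(y); // the only synchronization point in the block
    }
    return y;
}

int main() {
    // Example: a 4x4 weight, output-dim split, for rank 0 of a 2-node TP group.
    std::vector<float> W(16, 1.0f);
    std::vector<float> x = {1, 2, 3, 4};
    TensorShard shard = load_split_weight(W, 4, 4, SplitType::SPLIT_OUTPUT, /*rank=*/0, /*world=*/2);
    std::vector<float> y = mul_mat_tp(shard, x); // a 2-element slice, no all-reduce triggered
    std::printf("local slice has %zu elements\n", y.size());
}
```

The appeal of tagging at load time is that the compute graph itself would not need to change: each matmul call site stays the same, and the single 'all reduce' per attention or MLP block falls out of the split types chosen for the weights.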

Vladonai commented 2 weeks ago

I would like this method to support not only newer GPUs. For example, I have 4 Tesla P40s and would like to see a noticeable speedup.