b4rtaz / distributed-llama

Tensor parallelism is all you need. Run LLMs on weak devices or make powerful devices even more powerful by distributing the workload and dividing the RAM usage.

[Feature Suggestion] From All-Reduce to Ring-All-Reduce #69

Open zhengpeirong opened 1 month ago

zhengpeirong commented 1 month ago

Dear author,

Challenge and solution

This repository implements tensor parallelism, which distributes the computation workload evenly across nodes and achieves nearly linear speedup in inference time. The communication workload, however, is not distributed: the transfer time grows with the number of workers. This can be addressed by replacing all-reduce with ring-all-reduce, which spreads the data-transfer workload across every worker.

Let me briefly introduce the concepts of all-reduce and ring-all-reduce.

All-reduce

(figure: all-reduce in a master-worker topology) This master-worker architecture is what is currently implemented. It can be changed into a ring of workers.
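
To make the comparison concrete, here is a minimal single-process sketch of the master-worker pattern described above. It is not the repository's actual code; the function name and data layout are made up for illustration, and the "send"/"receive" steps are reduced to plain loops.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical model of master-worker all-reduce: the master collects
// every worker's partial vector, sums them, and would then broadcast
// the full result back. The master's link therefore carries on the
// order of (P - 1) * h elements in each direction.
std::vector<float> masterWorkerAllReduce(const std::vector<std::vector<float>>& partials) {
    const std::size_t h = partials.at(0).size();
    std::vector<float> result(h, 0.0f);
    for (const auto& partial : partials) {          // "receive" from each worker
        for (std::size_t i = 0; i < h; ++i) {
            result[i] += partial[i];
        }
    }
    // A real deployment would now "send" result back to every worker,
    // roughly doubling the traffic on the master's network link.
    return result;
}
```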

Ring-all-reduce

The ring-all-reduce algorithm is divided into two stages.

stage 1

First, place the P workers on a ring and divide each worker's data into P chunks. In your case, the hidden_dim is divided into P parts. (figure: data partitioned into P chunks per worker)

Next, consider the k-th worker: it sends its k-th chunk to the next worker and receives the (k-1)-th chunk from the previous worker. (figure: first ring exchange)

Afterwards, the worker accumulates the received (k-1)-th chunk into its own (k-1)-th chunk, and then sends the accumulated chunk to the next worker. (figure: accumulate-and-forward step)

After P-1 steps, each worker holds the fully accumulated result for one of the P chunks. (figure: state after the reduce-scatter stage)
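
A minimal single-process simulation of this first stage (a reduce-scatter over a ring) might look like the sketch below. This is not distributed-llama code; the ring is simulated in memory, h is assumed divisible by P, and worker indices stand in for real network peers.

```cpp
#include <cstddef>
#include <vector>

// Single-process simulation of stage 1 (reduce-scatter).
// data[k] is worker k's vector of length h; h is assumed divisible by P.
// After P-1 steps, worker k holds the fully reduced sum for chunk (k + 1) % P.
void ringReduceScatter(std::vector<std::vector<float>>& data, std::size_t P) {
    const std::size_t h = data[0].size();
    const std::size_t chunk = h / P;
    for (std::size_t step = 0; step < P - 1; ++step) {
        for (std::size_t k = 0; k < P; ++k) {
            // Worker k sends chunk (k - step) mod P to worker (k + 1) mod P,
            // which accumulates it into its own copy of that chunk.
            std::size_t dst = (k + 1) % P;
            std::size_t c = (k + P - step) % P;   // chunk index being passed on
            for (std::size_t i = c * chunk; i < (c + 1) * chunk; ++i) {
                data[dst][i] += data[k][i];
            }
        }
    }
}
```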

stage 2

In the second stage, each worker sends its fully accumulated chunk to the next worker; on receiving a chunk, a worker simply overwrites the corresponding part of its own data. After P-1 steps, every worker holds a full copy of the final result, identical to what all-reduce produces.
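
The second stage is an all-gather over the same ring. Continuing the single-process sketch from stage 1 (again an illustration, not the repository's implementation):

```cpp
#include <cstddef>
#include <vector>

// Single-process simulation of stage 2 (all-gather). Assumes
// ringReduceScatter has already run, so worker k holds the fully
// reduced values for chunk (k + 1) % P. Each step forwards the most
// recently completed chunk around the ring; after P-1 steps every
// worker holds the full reduced vector.
void ringAllGather(std::vector<std::vector<float>>& data, std::size_t P) {
    const std::size_t h = data[0].size();
    const std::size_t chunk = h / P;
    for (std::size_t step = 0; step < P - 1; ++step) {
        for (std::size_t k = 0; k < P; ++k) {
            std::size_t dst = (k + 1) % P;
            std::size_t c = (k + 1 + P - step) % P;  // chunk worker k forwards
            for (std::size_t i = c * chunk; i < (c + 1) * chunk; ++i) {
                data[dst][i] = data[k][i];           // overwrite, no summation
            }
        }
    }
}
```

Calling ringReduceScatter followed by ringAllGather on the simulated workers yields the same vectors as summing all the inputs directly, which is an easy way to sanity-check the chunk indexing.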

Assuming each worker's data is a vector of length hidden_dim = h, the amount of data sent (and received) by each worker is 2(P-1)*h/P, which is almost independent of the number of workers P.

- When P = 1, the transfer data is 0.
- When P = 2, the transfer data is h.
- When P = 4, the transfer data is 1.5*h, less than the current 3*h.
- When P = 8, the transfer data is 1.75*h, much less than the current 7*h.
- When P → ∞, the transfer data approaches 2*h, while the current scheme grows without bound.
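
These numbers fall straight out of the 2(P-1)*h/P formula; a tiny standalone check (in units of h) confirms them:

```cpp
#include <cstdio>

// Per-worker transfer volume of ring all-reduce, in units of h,
// compared with the (P - 1) * h the master handles per direction in
// the master-worker scheme (the "current" figures above).
int main() {
    const int workerCounts[] = {1, 2, 4, 8};
    for (int P : workerCounts) {
        double ring = 2.0 * (P - 1) / P;   // 2(P-1)/P per worker
        int master = P - 1;                // (P-1) on the master's link
        std::printf("P=%d  ring=%.2f*h  master=%d*h\n", P, ring, master);
    }
    return 0;
}
// Prints 0.00, 1.00, 1.50, 1.75 for the ring and 0, 1, 3, 7 for the master.
```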

Summary

Ring all-reduce avoids having the master handle O(h*P) data, as it must in the master-worker architecture; that becomes a network bottleneck once the number of devices grows to 8 or more.

Best Regards.

For your reference: Optimization of Collective Communication Operations in MPICH (PDF)

zhengpeirong commented 1 month ago

I am working on this. Please spend your time on other priorities. @b4rtaz