zhengpeirong opened this issue 5 months ago
Just as a supplement, the figure shows the detailed time cost of each task when 4 Raspberry Pis run Llama2-7B-Q40, so you can see how much time the aforementioned function costs. If you can parallelize these computations according to the 'Mature solution', the time will decrease nearly linearly as the number of devices increases. @b4rtaz
Nice measurements! It seems `multiheadAtt` is super slow.
@zhengpeirong please check the 0.3.1 version. Now all tasks are executed in parallel so it should be a bit better.
@b4rtaz The `qkv` change has been reverted. Do you plan to deal with this issue? Not only does `MulHead` cost time; `Finalize` also takes a big portion of it.
@zhengpeirong yes, I know. The `qkv` seems to be quite well optimized if you look at the other layers. Still, `qkv` may be improved in the way you suggested in the first post. I didn't have time to read it yet.
The problem with the `finalize` layer is that its output is large (`vocabSize`), so I don't think it's a good idea to synchronize it. But maybe it could be optimized so that each worker runs the sampler on its own slice of the output, and then the root node merges the results somehow. Different samplers would require different merging logic, but it looks doable (for example, `sample_argmax` looks super easy).
Yes, I want to keep working on this project. More hands are welcome. :-)
@b4rtaz Thanks for your persistence and endeavor.
`qkv` can be optimized; all you need to read is section "3. Model Parallel Transformers" of the paper "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism". `finalize` can be optimized as in your design. If you combine all those mechanisms, the non-parallel functions will be optimized! Here is the draft workflow:
```cpp
TransformerArch buildLlama2Arch(TransformerSpec* spec) {
    TransformerArch a;

    // Inference
    a.I(sendPoke, TASK_TYPE_TRANSFER);
    for (int i = 0; i < spec->nLayers; i++) {
        a.I(llamaRmsAttNorm, TASK_TYPE_INFERENCE); // Combine the existing llamaRmsAtt and llamaRmsAttNorm
        a.I(llamaQuantizeRmsAtt, TASK_TYPE_INFERENCE); // Quantization
        a.I(llamaSyncRmsAtt, TASK_TYPE_TRANSFER); // Sending
        a.I(llamaQkv, TASK_TYPE_INFERENCE); // Compute Q K V
        a.I(llamaMultiheadAtt, TASK_TYPE_INFERENCE); // Merge kv-cache, add RoPE encoding, compute a part of multi-head attention locally
        a.I(llamaAttOutput, TASK_TYPE_INFERENCE); // Worker computes W_O matrix
        a.I(llamaQuantizeAtt, TASK_TYPE_INFERENCE);
        a.I(llamaSyncAtt, TASK_TYPE_TRANSFER); // First time-consuming communication
        a.I(llamaDequantizeAtt, TASK_TYPE_INFERENCE);
        a.I(llamaMergeAtt, TASK_TYPE_INFERENCE); // Merge all attention matrices
        a.I(llamaRmfFfn, TASK_TYPE_INFERENCE);
        a.I(llamaRmfFfnNorm, TASK_TYPE_INFERENCE);
        a.I(llamaQuantizeRmfFfn, TASK_TYPE_INFERENCE);
        a.I(llamaSyncRmfFfn, TASK_TYPE_TRANSFER);
        a.I(llamaFfn, TASK_TYPE_INFERENCE); // Compute SwiGLU activation
        a.I(llamaFfn2, TASK_TYPE_INFERENCE); // Compute the second FFN
        a.I(llamaQuantizeFfn2, TASK_TYPE_INFERENCE);
        a.I(llamaSyncFfn2, TASK_TYPE_TRANSFER); // Second time-consuming communication
        a.I(llamaDequantizeFfn2, TASK_TYPE_INFERENCE);
        a.I(llamaMergeFfn2, TASK_TYPE_INFERENCE);
        a.I(llamaNextBlock, TASK_TYPE_INFERENCE);
    }
    a.I(llamaRmsFinal, TASK_TYPE_INFERENCE);
    a.I(llamaRmsFinalNorm, TASK_TYPE_INFERENCE);
    a.I(llamaLogits, TASK_TYPE_INFERENCE);
    a.I(llamaQuantizeLogits, TASK_TYPE_INFERENCE);
    a.I(llamaSyncLogits, TASK_TYPE_TRANSFER);
    a.I(llamaDequantizeLogits, TASK_TYPE_INFERENCE);
    a.I(llamaMergeLogits, TASK_TYPE_INFERENCE);

    // Worker
    for (int i = 0; i < spec->nLayers; i++) {
        a.W(llamaSyncRmsAtt, TASK_TYPE_TRANSFER);
        a.W(llamaQkv, TASK_TYPE_INFERENCE); // Compute Q K V
        a.W(llamaMultiheadAtt, TASK_TYPE_INFERENCE); // Merge kv-cache, add RoPE encoding, compute a part of multi-head attention locally
        a.W(llamaAttOutput, TASK_TYPE_INFERENCE); // Worker computes W_O matrix
        a.W(llamaQuantizeAtt, TASK_TYPE_INFERENCE);
        a.W(llamaSyncAtt, TASK_TYPE_TRANSFER);
        a.W(llamaSyncRmfFfn, TASK_TYPE_TRANSFER);
        a.W(llamaFfn, TASK_TYPE_INFERENCE);
        a.W(llamaFfn2, TASK_TYPE_INFERENCE);
        a.W(llamaQuantizeFfn2, TASK_TYPE_INFERENCE);
        a.W(llamaSyncFfn2, TASK_TYPE_TRANSFER);
        a.W(llamaNextBlock, TASK_TYPE_INFERENCE);
    }
    a.W(llamaLogits, TASK_TYPE_INFERENCE);
    a.W(llamaQuantizeLogits, TASK_TYPE_INFERENCE);
    a.W(llamaSyncLogits, TASK_TYPE_TRANSFER);

    return a;
}
```
I hope this repo can catch up with the state-of-the-art algorithms as soon as possible!

By my rough estimate, the optimized generation time will be only 72% of the original, i.e. a 1.39x speedup (1/0.72 ≈ 1.39) over this version. Specifically, the main transfer happens only twice per block, and the root node's workload is divided among 4 workers.
@zhengpeirong this is just a guess; have you proven it with any implementation?

Currently I've noticed a problem with the rope layer: it's not easy to split, because to calculate the output of this layer we need:

Output digits `<0; kvDim)` = q & k
Output digits `<kvDim; dim)` = q

The current implementation divides the q and k outputs into equal parts (`<0; s)`, `<s; s+1)`, ...). This won't work for the rope, because the first node would require a bit of the `k` output from the second node, etc.

I see some possibility of solving it, but it is much more complex than I thought. I should probably split the Q & K layers into many small columns (width=2) and assign columns to nodes.

worker 1: output digits 1, 2, 6, 8, ... (n, n + 1)
worker 2: output digits 2, 4, 9, 10, ... (n + 2, n + 3)

The paper that you linked doesn't have any part about the rope, so we probably have a different case here.
The issue you are currently facing lies in calculating the QKV matrices separately, split along the hidden dimension; with that split, RoPE cannot easily be divided. However, if tensor parallelism is supported, the split is performed along the num_head dimension, dividing the attention heads across devices. This is independent of the hidden dimension where RoPE operates, thus avoiding the problem you encountered.

In summary, the RoPE computation and the multi-head attention computation are orthogonal, operating on different dimensions: the former on the hidden dimension and the latter on the num_head dimension. The RoPE part can easily be completed separately on each device.
I needed a bit of time to notice my thinking error. Now the rope layer is split across the root node and workers. Tested with 1, 2 and 4 nodes; the macbeth test generates the same output on different topologies*.

* The macbeth test doesn't work with buffer quantization (it generates a different output), because now the RoPE is applied before the transfer quantization, whereas previously it was applied after the transfer dequantization. I expect this affects the perplexity somehow. This will probably be resolved once the `llamaMultiheadAttJoin` function is also split out.

Now all nodes have a RoPE cache, and its size differs between nodes. This may be optimized a bit, but "so far so good".
```
root node: ropeCache:  8192 kB
worker 1:  ropeCache: 28672 kB
worker 2:  ropeCache: 20480 kB
worker 3:  ropeCache: 26624 kB
```
Next, I'll try to split out the `llamaMultiheadAttJoin` function.
Finally I split out the multihead layer across all nodes (still not merged; I need to fix the mixtral & grok architectures). First measurements:

Model: Llama 3 8B Q40, Buffer: Q80, Setup: 4 x Raspberry Pi 5 8GB + TP-Link LS1008G Switch
| Devices | 0.3.0 | This PR | Percentage change |
|---|---|---|---|
| 2 x Raspberry Pi 5 | S 646 kB + R 476 kB = 1122 kB | S 578 kB + R 442 kB = 1020 kB | -9.09% |
| 4 x Raspberry Pi 5 | S 2295 kB + R 714 kB = 3009 kB | S 2193 kB + R 663 kB = 2856 kB | -5.08% |
| Devices | Metric | 0.3.0 | This PR | Percentage change |
|---|---|---|---|---|
| 2 x Raspberry Pi 5 | Avg generation time | 444.27 ms | 381.81 ms | |
| | Avg inference time | 362.73 ms | 349.94 ms | -3.53% |
| | Avg transfer time | 80.11 ms | 30.31 ms* | |
| 4 x Raspberry Pi 5 | Avg generation time | 331.47 ms | 359.44 ms | |
| | Avg inference time | 267.62 ms | 258.00 ms | -3.59% |
| | Avg transfer time | 62.34 ms | 99.69 ms | |
* I think the switch used is completely non-deterministic; it achieves a random speed at different times. So I recommend comparing only the avg inference time.

It looks like that gave a tiny speedup (maybe 3%). I expected a bit more.
Update: I changed the implementation a bit; now there is no synchronization between `llamaQuantizeMultiheadAtt` and `llamaAtt`. So basically we now have state-of-the-art parallelism of the attention layers.
| Devices | 0.3.0 | PR v2 | Percentage change |
|---|---|---|---|
| 2 devices | S 646 kB + R 476 kB = 1122 kB | S 510 kB + R 442 kB = 952 kB | -15.15% |
| 4 devices | S 2295 kB + R 714 kB = 3009 kB | S 1887 kB + R 867 kB = 2754 kB | -8.47% |
| 8 devices | S 5771 kB + R 833 kB = 6604 kB | S 4819 kB + R 1487 kB = 6306 kB | -4.51% |
The final state of the attention synchronization looks like this for a single block:

```
root --- xb  ---> node
root <-- xbv ---- node
merge att
```

The previous implementation:

```
root --- xb  ---> node
root <-- q   ---- node
root <-- k   ---- node
root <-- v   ---- node
root --- xb  ---> node
root <-- xb2 ---- node
merge att
```
@b4rtaz You have completed state-of-the-art tensor parallelism for the attention layer! Moreover, continuing our earlier discussion, there are still two optimizations that can be done:

Computation: the last layer (`Finalize`) occupies 11% of the total time. It can be decomposed into parallel computing + synchronization (merge). With 4 workers, that 11% can be reduced to 2.75% + synchronization.
Communication: currently there are 3 main synchronization functions in one transformer block; the attention layer takes 1 and the FFN layer takes 2. You are using 2 all-gather operations (`syncFfnA` and `syncFfn2`) in the FFN. They can be optimized into 1 all-reduce operation (`syncFfn`). The slicing approach is explained in detail in the PyTorch tutorial:

```python
"feed_forward.w1": ColwiseParallel(),
"feed_forward.w2": RowwiseParallel(),
"feed_forward.w3": ColwiseParallel(),
```

Then 4.76% of the time can be reduced.
In summary, at most a 12% acceleration can be achieved over the current version. With more workers (4 in this issue), this acceleration would benefit from even more parallelism.
@b4rtaz Just for your reference, this code implements the FFN layer of llama with tensor-parallel acceleration. In summary, tensor parallelism splits the attention layer along only the `head` dimension, while for the MLP layer it splits along the intermediate `hidden` dimension.
@zhengpeirong it seems that after I adjusted the mlp layers to your suggestion, the transfer has dropped by ~40% per token.
| Devices | 0.5.0 | 0.7.1 | Percentage change |
|---|---|---|---|
| 2 devices | S 510 kB + R 442 kB = 952 kB | S 272 kB + R 272 kB = 544 kB | -42.8% |
| 4 devices | S 1887 kB + R 867 kB = 2754 kB | S 816 kB + R 816 kB = 1632 kB | -40.7% |
Later I'll check the impact on the generation time.
Where do you see the generation time data?
> @zhengpeirong it seems after I adjusted mlp layers to your suggestion the transfer has dropped by ~40% per token.
272 kB is compatible with the theoretical analysis:

272 / 32 / 4 ≈ 2

Excluding the transfer data for the embedding layer, we can treat this as 2. This means there are 2 all-reduce transfers in a single transformer block. And S = R is exactly what an all-reduce will show.

Congratulations on finishing this feature suggestion!
nTokens = 90, buffer = Q80
4 x Raspberry Pi 5 8GB
| Version | Avg tokens / second | Avg generation time | Avg inference time | Avg transfer time |
|---|---|---|---|---|
| 0.7.1 | 4.08 | 245.08 ms | 169.33 ms | 75.34 ms |
| 0.7.0 | 3.90 | 256.23 ms | 168.77 ms | 87.12 ms |
| 0.6.0 | 4.24 | 235.69 ms | 143.44 ms | 91.77 ms |
2 x Raspberry Pi 5 8GB
| Version | Avg tokens / second | Avg generation time | Avg inference time | Avg transfer time |
|---|---|---|---|---|
| 0.7.1 | 3.07 | 325.46 ms | 269.04 ms | 56.39 ms |
| 0.7.0 | 2.91 | 343.44 ms | 266.51 ms | 76.87 ms |
| 0.6.0 | 3.06 | 327.17 ms | 249.80 ms | 77.28 ms |
nTokens = 128, buffer = Q80
2 x Raspberry Pi 5 8GB
| Version | Avg tokens / second | Avg generation time | Avg inference time | Avg transfer time |
|---|---|---|---|---|
| 0.7.1 | 16.86 | 59.31 ms | 50.37 ms | 8.58 ms |
| 0.7.0 | 15.17 | 65.93 ms | 52.07 ms | 13.45 ms |
nTokens = 90, buffer = Q80
2 x AMD EPYC 7402P 24-Core Processor
| Version | Avg tokens / second | Avg generation time | Avg inference time | Avg transfer time |
|---|---|---|---|---|
| 0.7.1 | 13.04 | 76.67 ms | 45.33 ms | 30.93 ms |
| 0.7.0 | 12.79 | 78.21 ms | 46.30 ms | 31.49 ms |
| 0.6.0 | 12.55 | 79.71 ms | 47.08 ms | 32.22 ms |
In all cases the average transfer time has dropped. Interestingly, the non-blocking sockets reduce the speed on a Raspberry Pi but not on a strong machine. Maybe this mode should be optional.
> In all cases the average transfer time has dropped. What is interesting the non-blocking sockets reduce the speed on Raspberry Pi but on a strong machine not. Maybe this mode should be optional.
Do you mean that blocking sockets reduce the speed?

Could you try 8 x Raspberry Pi? Since there are obvious transfer delays with 8 devices, I am curious whether it's because of network traffic congestion.

BTW, I think it's time to update the README.md with the newest generation times for people new here.
> Do you mean blocking sockets reduces the speed?
No, the non-blocking sockets, I think. Since 0.6.1, Distributed Llama has enabled non-blocking sockets for root <> node communication.
> Could you try 8 x Raspberry Pi?
Unfortunately I don't have 8 devices anymore. I have only 4 x Raspberry Pi 5 8GB.
> BTW, I think it's time to update the README.md with the newest generation time for people new here.
You're right. I'll do it soon.
> Do you mean blocking sockets reduces the speed?
>
> No, the non-blocking sockets, I think. Since 0.6.1, Distributed Llama has enabled non-blocking sockets for root <> node communication.
Non-blocking sockets should let the CPU do other jobs instead of waiting. But what's the logical connection between non-blocking sockets and increased inference time?
> Could you try 8 x Raspberry Pi?
>
> Unfortunately I don't have 8 devices anymore. I have only 4 x Raspberry Pi 5 8GB.
In this discussion, you are invited to conduct experiments with more devices and find out what number of devices is the best choice, then present it in the README. If `dllama` can support any power-of-two number of devices, more scenarios can be supported, since there is a big gap between 8, 16 and 32 devices.
> The non-blocking sockets will make the CPU do other jobs instead of waiting. But what's the logical connection between non-blocking and increased inference time?
I think this problem appears only on slow devices like a Raspberry Pi. I cannot explain it, but you can see the drop in speed from 0.6.0 to 0.7.0, and this was only a minor change between those versions. Maybe we need more tests.
Dear Author,
Your contribution is critical for the open-source community. The distributed-llama repo has implemented tensor parallelism from scratch, and the result is amazingly significant. However, there are still improvements that could be made. Because of my limited coding ability, I am not able to make the improvements myself, so I hope you can look at my suggestions below.
Challenge: the root node's special task and synchronization

When I run repo version '0.1.0', I find that the `softmax` operations in `MultiHead` are conducted on the root node only. This operation costs a significant portion of the total time. Second, the `synFfnA` and `synFfn2` functions also cost a lot of time.

Mature solutions

In fact, these challenges have been addressed in this paper: https://arxiv.org/abs/1909.08053. Its solution is shown in the image: it conducts the attention mechanism (softmax) on every worker. Second, the matrix segmentation uses a column split and a row split in two consecutive matrices, thus reducing two synchronization operations to one.
If you are willing to make further improvements to the repo, the following is the mature solution for every component of `llama2` using tensor parallelism and sequence parallelism: https://pytorch.org/tutorials/intermediate/TP_tutorial.html However, it's implemented in Python, and you would be the first to implement the solution in C++.

Thanks for your contribution!!! Best Regards