b4rtaz / distributed-llama

Tensor parallelism is all you need. Run LLMs on weak devices or make powerful devices even more powerful by distributing the workload and dividing the RAM usage.
MIT License
1.02k stars, 68 forks

Support nSlices > nKvHeads #70

Open b4rtaz opened 1 month ago

b4rtaz commented 1 month ago

After the attention layers were split across all nodes, I missed the implications this introduced.


Long story short: to calculate the attention for a single Q head, I need the whole corresponding head from the K output. For Q head x, I need the whole K head floor(x / (nHeads / nKvHeads)) to calculate the result.

For example Llama 3 8B:

💡 dim: 4096
💡 nHeads: 32
💡 nKvHeads: 8

Q head 0  => floor(  0 / ( 32 / 8) ) => K head 0
Q head 1  => floor(  1 / ( 32 / 8) ) => K head 0
Q head 2  => floor(  2 / ( 32 / 8) ) => K head 0
...
Q head 8  => floor(  8 / ( 32 / 8) ) => K head 2
Q head 9  => floor(  9 / ( 32 / 8) ) => K head 2
...
Q head 31 => floor( 31 / ( 32 / 8) ) => K head 7
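A minimal Python sketch of the mapping above, assuming nHeads is divisible by nKvHeads so the grouping is exact. `kv_head_for_q_head` is a hypothetical helper for illustration, not part of Distributed Llama.

```python
def kv_head_for_q_head(q_head: int, n_heads: int, n_kv_heads: int) -> int:
    """Index of the K (and V) head that Q head `q_head` attends with."""
    group_size = n_heads // n_kv_heads  # Q heads sharing one K/V head: 32 // 8 = 4
    return q_head // group_size         # floor(x / (nHeads / nKvHeads))


n_heads, n_kv_heads = 32, 8  # Llama 3 8B
for q in (0, 1, 2, 8, 9, 31):
    print(f"Q head {q:2d} => K head {kv_head_for_q_head(q, n_heads, n_kv_heads)}")
```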

Because of this, it is currently not possible to split the model across more than nKvHeads nodes.

^ The same problem applies to the V layer.
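A minimal Python sketch that checks, for a given nSlices, whether every slice already computes the K outputs its own Q heads need. It assumes, hypothetically, that both the Q heads and the kvDim K/V outputs are split evenly and in order across slices, and that nHeads and kvDim are divisible by nSlices. The helper names are illustrative only.

```python
from typing import Tuple


def needed_kv_range(s: int, n_slices: int, n_heads: int, n_kv_heads: int,
                    head_size: int) -> Tuple[int, int]:
    """K outputs [start, end) that slice `s` needs for the Q heads it owns."""
    heads_per_slice = n_heads // n_slices
    group_size = n_heads // n_kv_heads        # Q heads sharing one K head
    first_q = s * heads_per_slice
    last_q = first_q + heads_per_slice - 1
    return (first_q // group_size) * head_size, (last_q // group_size + 1) * head_size


def owned_kv_range(s: int, n_slices: int, kv_dim: int) -> Tuple[int, int]:
    """K outputs [start, end) that slice `s` computes under an even split (kvDim0 each)."""
    kv_dim0 = kv_dim // n_slices
    return s * kv_dim0, (s + 1) * kv_dim0


def split_is_local(n_slices: int, n_heads: int = 32, n_kv_heads: int = 8,
                   head_size: int = 128) -> bool:
    """True if no slice needs K outputs that another slice computes."""
    kv_dim = n_kv_heads * head_size           # 1024 for Llama 3 8B
    for s in range(n_slices):
        need_start, need_end = needed_kv_range(s, n_slices, n_heads, n_kv_heads, head_size)
        own_start, own_end = owned_kv_range(s, n_slices, kv_dim)
        if need_start < own_start or need_end > own_end:
            return False
    return True


for n in (1, 2, 4, 8, 16, 32):
    print(f"nSlices={n:2d}: local={split_is_local(n)}")
```

For Llama 3 8B the check passes for 1, 2, 4 and 8 slices and fails for 16 and 32: with 16 slices, slice 0 needs K outputs 0..127 but computes only 0..63.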


How could this be fixed?

1. Synchronize missing outputs

For nSlices > nKvHeads setups, a new synchronization step could be introduced. This step would synchronize the missing K/V outputs across nodes. Of course, synchronization is the slowest part of Distributed Llama.
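A sketch of what such a synchronization plan could compute: for slice s, which peer slices hold the K outputs it is missing, and which ranges. The same plan would apply to V. It relies on a hypothetical even, in-order split of both Q heads and K/V outputs; `missing_kv_transfers` is an illustrative name, not an existing function.

```python
from typing import List, Tuple

Transfer = Tuple[int, int, int]  # (source slice, start, end) of K outputs to receive


def missing_kv_transfers(s: int, n_slices: int, n_heads: int = 32,
                         n_kv_heads: int = 8, head_size: int = 128) -> List[Transfer]:
    """Hypothetical sync plan: K outputs slice `s` must receive, and from which slice."""
    kv_dim = n_kv_heads * head_size
    kv_dim0 = kv_dim // n_slices           # K outputs computed by each slice
    heads_per_slice = n_heads // n_slices
    group_size = n_heads // n_kv_heads     # Q heads sharing one K head
    first_q = s * heads_per_slice
    last_q = first_q + heads_per_slice - 1
    need_start = (first_q // group_size) * head_size
    need_end = (last_q // group_size + 1) * head_size

    transfers = []
    for t in range(n_slices):
        if t == s:
            continue  # slice s already has its own outputs
        own_start, own_end = t * kv_dim0, (t + 1) * kv_dim0
        start, end = max(need_start, own_start), min(need_end, own_end)
        if start < end:
            transfers.append((t, start, end))
    return transfers


print(missing_kv_transfers(0, 16))
print(missing_kv_transfers(0, 32))
```

For Llama 3 8B with 16 slices, slice 0 would receive outputs 64..127 from slice 1; with 32 slices, it would receive three 32-output chunks from slices 1 to 3. In these cases each slice receives headSize - kvDim0 values for K (and again for V), the same count as the redundancy in option 2.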

2. Redundancy

Redundancy could be introduced for the K/V layers. These layers would be split with alignment to headSize, so each node computes the whole K/V head it needs. This way no synchronization is required, and the amount of redundant computation seems small (headSize - kvDim0 outputs per node).

For example Llama 3 8B:

headSize = dim / nHeads = 128
kvDim = (dim * nKvHeads) / nHeads = 1024

nSlices = 16
kvDim0 = kvDim / nSlices = 64
redundancy = 128 - 64 = 64 outputs of K & V

nSlices = 32
kvDim0 = kvDim / nSlices = 32
redundancy = 128 - 32 = 96 outputs of K & V
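The arithmetic above, plus the headSize-aligned range each slice would compute, as a short Python sketch. Defaults are the Llama 3 8B values (dim 4096, nHeads 32, nKvHeads 8); `kv_layout` and `aligned_kv_range` are hypothetical helpers, and the sketch assumes nSlices is a multiple of nKvHeads and divides nHeads, so every slice needs exactly one K/V head.

```python
from typing import Tuple


def kv_layout(n_slices: int, dim: int = 4096, n_heads: int = 32,
              n_kv_heads: int = 8) -> Tuple[int, int, int, int]:
    """Return (headSize, kvDim, kvDim0, redundancy) per slice."""
    assert n_slices % n_kv_heads == 0 and n_heads % n_slices == 0
    head_size = dim // n_heads                 # 4096 / 32 = 128
    kv_dim = (dim * n_kv_heads) // n_heads     # (4096 * 8) / 32 = 1024
    kv_dim0 = kv_dim // n_slices               # even share of K/V outputs per slice
    return head_size, kv_dim, kv_dim0, head_size - kv_dim0


def aligned_kv_range(s: int, n_slices: int, dim: int = 4096, n_heads: int = 32,
                     n_kv_heads: int = 8) -> Tuple[int, int]:
    """K/V outputs [start, end) slice `s` computes when the split is aligned to headSize."""
    assert n_slices % n_kv_heads == 0 and n_heads % n_slices == 0
    head_size = dim // n_heads
    q_head = s * (n_heads // n_slices)            # first Q head owned by this slice
    kv_head = q_head // (n_heads // n_kv_heads)   # the single K/V head it needs
    return kv_head * head_size, (kv_head + 1) * head_size


for n in (16, 32):
    _, _, kv_dim0, extra = kv_layout(n)
    print(f"nSlices={n}: kvDim0={kv_dim0}, redundancy={extra} outputs of K & V")
```

This reproduces the numbers above. With 16 slices, slices 0 and 1 would both compute K/V outputs 0..127, each doing 64 outputs beyond its even share.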