satojkovic opened 2 months ago
This chart in the Llama 3 paper has an error: the number of key/value cache heads for the 405B model is 16, not 8. You can find the answer in this link.
@ZeusXuan Thank you for the comment! I read the Reddit post. Does this mean that the 405B model originally had 16 KV heads, but it has since been changed to 8, the same as in the paper? I found the following link to the commit that fixes it to 8 KV heads: https://huggingface.co/meta-llama/Meta-Llama-3.1-405B-FP8/discussions/15
I have a question about how the inference memory requirements for Llama 3.1 are calculated in this post: https://huggingface.co/blog/llama31#inference-memory-requirements
I used the formula from this article to do my own calculations. The formula in the article is as follows:
This gives the size of the KV cache per token, where the leading factor of 2 accounts for the K and V matrices. `num_layers`, `num_heads`, and `dim_heads` are the values from the Llama 3 paper (here `num_heads` is the number of key/value heads, not query heads).
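Written out, the per-token formula described above looks like the following. Converting it to bytes needs the storage size of one element, written here as a hypothetical `bytes_per_element` and assumed to be 2 (FP16/BF16), since the description above does not state it:

$$
\text{KV cache per token (bytes)} = 2 \times \mathrm{num\_layers} \times \mathrm{num\_heads} \times \mathrm{dim\_heads} \times \mathrm{bytes\_per\_element}
$$

For the 8B model (32 layers, 8 KV heads, head dimension 128) this gives $2 \times 32 \times 8 \times 128 \times 2 = 131072$ bytes, i.e. 128 KiB per token; the total is this value multiplied by the number of tokens.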
For example, for the 8B model with 16k tokens and 128k tokens, the calculation is as follows and matches the numbers in the table above.
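A minimal Python sketch of the 8B calculation. It assumes the 8B shape from the Llama 3 paper (32 layers, 8 key/value heads, head dimension 128), 2 bytes per element (FP16/BF16), "16k" and "128k" taken as 16,000 and 128,000 tokens, and results reported in GiB (2^30 bytes). The unit and token-count conventions are assumptions about how the table was produced, and the helper names are hypothetical.

```python
# Sketch: KV cache size for the Llama 3.1 8B model.
# Assumptions: 32 layers, 8 KV heads, head dim 128, FP16/BF16 (2 bytes/element), GiB units.

BYTES_PER_ELEMENT = 2  # assumption: FP16/BF16 KV cache
GIB = 1024 ** 3


def kv_cache_bytes_per_token(num_layers: int, num_kv_heads: int, head_dim: int,
                             bytes_per_element: int = BYTES_PER_ELEMENT) -> int:
    # The factor 2 covers the K and V tensors cached for every layer.
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_element


def kv_cache_gib(num_tokens: int, num_layers: int, num_kv_heads: int, head_dim: int) -> float:
    # Total cache = tokens x bytes per token, expressed in GiB.
    return num_tokens * kv_cache_bytes_per_token(num_layers, num_kv_heads, head_dim) / GIB


print(kv_cache_bytes_per_token(32, 8, 128))  # 131072 bytes = 128 KiB per token
for tokens in (16_000, 128_000):
    # prints 1.953 GiB for 16,000 tokens and 15.625 GiB for 128,000 tokens
    print(f"{tokens} tokens: {kv_cache_gib(tokens, 32, 8, 128):.3f} GiB")
```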
However, when I compute the 16k- and 128k-token cases for the 405B model in the same way, the numbers do not match the table above: my values appear to be half of those in the table.
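The same kind of sketch can check whether the factor-of-two gap for the 405B model comes from the KV head count. It assumes 126 layers and head dimension 128 (from the Llama 3 paper), 2 bytes per element, and GiB units, and compares 8 key/value heads against 16, the value raised in the earlier comment. The helper name is hypothetical.

```python
# Sketch: 405B KV cache size with 8 vs. 16 key/value heads.
# Assumptions: 126 layers, head dim 128, FP16/BF16 (2 bytes/element), GiB units.

GIB = 1024 ** 3


def kv_cache_gib(num_tokens: int, num_layers: int, num_kv_heads: int,
                 head_dim: int, bytes_per_element: int = 2) -> float:
    # 2 = one K and one V tensor per layer
    per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_element
    return num_tokens * per_token / GIB


for tokens in (16_000, 128_000):
    with_8 = kv_cache_gib(tokens, 126, 8, 128)
    with_16 = kv_cache_gib(tokens, 126, 16, 128)
    print(f"{tokens} tokens: {with_8:.2f} GiB (8 KV heads), {with_16:.2f} GiB (16 KV heads)")
```

Because the KV cache is linear in the number of key/value heads, going from 8 to 16 heads exactly doubles every value. If the table was computed with 16 KV heads, that alone would explain why the 8-head results come out at half the table's values; this is an inference, not something the post confirms.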
Am I misunderstanding something? Or is there another factor that needs to be taken into account for the 405B model?
Also, my 1k-token numbers differ slightly from the table. Does the table use 1024 tokens for "1k"?
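One way to check the 1k column is to compute both interpretations for the 8B model, under the same assumptions as before (32 layers, 8 key/value heads, head dimension 128, 2 bytes per element, GiB units). `BYTES_PER_TOKEN_8B` is a hypothetical name used only for this sketch.

```python
# Sketch: does "1k tokens" mean 1000 or 1024 tokens?
# Assumptions: 8B shape (32 layers, 8 KV heads, head dim 128), 2 bytes/element, GiB units.

GIB = 1024 ** 3
# K and V, per layer, per KV head, per head dimension, 2 bytes each
BYTES_PER_TOKEN_8B = 2 * 32 * 8 * 128 * 2

for tokens in (1000, 1024):
    print(f"{tokens} tokens: {tokens * BYTES_PER_TOKEN_8B / GIB:.4f} GiB")
```

1024 tokens gives exactly 0.125 GiB, while 1000 tokens gives about 0.122 GiB; whichever one matches the table's 1k entry indicates the convention it uses.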
Thank you!