huggingface / blog

Public repo for HF blog posts
https://hf.co/blog

Llama3.1 inference memory requirements #2345

Open · satojkovic opened this issue 2 months ago

satojkovic commented 2 months ago

https://huggingface.co/blog/llama31#inference-memory-requirements

I have a question about how the inference memory requirements for Llama 3.1 are calculated in this post.

The table below is an excerpt of the FP16 KV cache sizes from the post:

| Model Size | 1k tokens | 16k tokens | 128k tokens |
|---|---|---|---|
| 8B | 0.125 GB | 1.95 GB | 15.62 GB |
| 70B | 0.313 GB | 4.88 GB | 39.06 GB |
| 405B | 0.984 GB | 15.38 GB | 123.05 GB |

I used the formula in this article for my own calculations. The formula in the article is as follows:

[image: KV cache size formula from the blog post]

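This gives the size of the KV cache per token, where the first factor of 2 accounts for the K and V matrices. num_layers, num_heads (here the number of key/value heads) and dim_heads (the model dimension divided by the number of attention heads) refer to the values in the Llama 3 paper. In the calculations below, the final factor of 2 is the number of bytes per FP16 value.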
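Since the formula image did not survive, here is the calculation as it can be reconstructed from the arithmetic in this issue (a sketch, not a copy of the blog's figure). Here $n_{\text{kv\_heads}}$ is the number of key/value heads used with grouped-query attention, $b$ is the number of bytes per value, and "GB" means $1024^3$ bytes, as in the snippets below:

```latex
\begin{aligned}
\text{KV bytes per token} &= 2 \times n_{\text{layers}} \times n_{\text{kv\_heads}} \times d_{\text{head}} \times b,
  \qquad d_{\text{head}} = d_{\text{model}} / n_{\text{heads}} \\
\text{KV cache (GB)} &= n_{\text{tokens}} \times \text{KV bytes per token} \,/\, 1024^{3} \\
\text{8B, FP16 } (b = 2)\text{:}\quad \text{KV bytes per token} &= 2 \times 32 \times 8 \times (4096/32) \times 2 = 131072\ \text{bytes} = 128\ \text{KiB}
\end{aligned}
```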

[image: table of Llama 3 model hyperparameters from the paper]

For example, for the 8B model with 16k tokens and 128k tokens, the calculation is as follows and matches the numbers in the table above.

```python
16000 * (2 * 32 * 8 * (4096/32) * 2) / 1024**3
# 1.953125
128000 * (2 * 32 * 8 * (4096/32) * 2) / 1024**3
# 15.625
```

However, if I calculate 16k tokens and 128k tokens the same way for the 405B model, the numbers do not match those in the table above. The calculated values are half of the values in the table (to within rounding).

```python
16000 * (2 * 126 * 8 * (16384/128) * 2) / 1024**3
# 7.6904296875
128000 * (2 * 126 * 8 * (16384/128) * 2) / 1024**3
# 61.5234375
```
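One way to confirm that exactly one factor of 2 is missing is to divide each table value by the computed value. This sketch assumes the 405B values from the paper (126 layers, 8 KV heads, head dimension 16384 / 128 = 128, FP16) and hard-codes the two table entries; `kv_gib` is a helper defined here for illustration only.

```python
def kv_gib(n_tokens, num_layers, num_kv_heads, head_dim, bytes_per_value=2):
    """KV cache size in GiB (1024**3 bytes): per token, 2 (K and V) * layers * KV heads * head dim * bytes."""
    return n_tokens * 2 * num_layers * num_kv_heads * head_dim * bytes_per_value / 1024**3


# 405B entries from the blog table excerpt above (GB), keyed by token count.
table_405b = {16000: 15.38, 128000: 123.05}

for n_tokens, table_value in table_405b.items():
    # Paper values: 126 layers, 8 KV heads, head dim 16384 / 128 = 128, FP16.
    computed = kv_gib(n_tokens, num_layers=126, num_kv_heads=8, head_dim=16384 // 128)
    print(f"{n_tokens} tokens: computed {computed:.2f}, table {table_value}, ratio {table_value / computed:.3f}")
```

Both ratios print as 2.000, so doubling any single term (layers, KV heads, head dimension, or bytes per value) would reproduce the table.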

Am I misunderstanding something? Or is there another factor that needs to be taken into account for the 405B model?

Also, for 1k tokens, the numbers are slightly different. Does the table count 1k as 1024 tokens?

```python
1000 * (2 * 32 * 8 * (4096/32) * 2) / 1024**3
# 0.1220703125
1024 * (2 * 32 * 8 * (4096/32) * 2) / 1024**3
# 0.125
```
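Evaluating every column header under both readings of "k" (1000 and 1024) for the 8B model gives a fuller picture. This is a small sketch using the same factors as the snippets above; `gib` and `BYTES_PER_TOKEN_8B` are illustrative names.

```python
# Per-token KV cache size for the 8B model in FP16, using the same factors as the
# snippets above: 2 (K and V) * 32 layers * 8 KV heads * 128 head dim * 2 bytes.
BYTES_PER_TOKEN_8B = 2 * 32 * 8 * (4096 // 32) * 2


def gib(n_tokens: int) -> float:
    """KV cache size in GiB (1024**3 bytes) for n_tokens tokens of the 8B model."""
    return n_tokens * BYTES_PER_TOKEN_8B / 1024**3


# Evaluate each column header under both readings: "k" = 1000 and "k" = 1024.
headers = [("1k", 1_000, 1_024), ("16k", 16_000, 16_384), ("128k", 128_000, 131_072)]
for label, decimal_tokens, binary_tokens in headers:
    print(f"{label}: {gib(decimal_tokens):.4f} GB (k=1000) vs {gib(binary_tokens):.4f} GB (k=1024)")
```

Against the table (0.125, 1.95, 15.62 GB), the 1k column matches k=1024 while the 16k and 128k columns match k=1000, so the excerpt appears to mix the two conventions. That is an inference from the numbers, not something the post states.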

Thank you!

ZeusXuan commented 2 months ago

The chart in the Llama 3 paper is wrong: the number of key/value heads for the 405B model is 16, not 8. You can find the answer in this link.
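A quick check of this explanation: recompute the 405B row with 16 KV heads, keeping the other paper values (126 layers, head dimension 16384 / 128 = 128, FP16). The sketch assumes the table counts 1k as 1024 tokens and 16k/128k as 16000/128000 tokens, as the arithmetic in the question suggests; `kv_cache_gib` is an illustrative helper.

```python
def kv_cache_gib(n_tokens, num_layers, num_kv_heads, head_dim, bytes_per_value=2):
    """KV cache size in GiB (1024**3 bytes); the factor 2 counts one K and one V per layer."""
    return n_tokens * 2 * num_layers * num_kv_heads * head_dim * bytes_per_value / 1024**3


# 405B: 126 layers and head dim 16384 / 128 = 128 from the paper, but 16 KV heads
# instead of 8. Token counts follow the table's apparent convention.
for label, n_tokens in (("1k", 1024), ("16k", 16000), ("128k", 128000)):
    size = kv_cache_gib(n_tokens, num_layers=126, num_kv_heads=16, head_dim=16384 // 128)
    print(f"{label}: {size:.6f} GB")
```

The results (0.984375, 15.380859 and 123.046875 GB) round to the table's 0.984, 15.38 and 123.05 GB.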

satojkovic commented 2 months ago

@ZeusXuan Thank you for the comment! I read the Reddit post. Does this mean that the 405B model originally had 16 KV heads but has since been changed to 8, matching the paper? I found the following link to the commit that changes it to 8 KV heads: https://huggingface.co/meta-llama/Meta-Llama-3.1-405B-FP8/discussions/15
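If the 405B config does use 8 KV heads, the 8B and 70B rows of the blog table would stay as they are and the 405B row would presumably halve. A minimal sketch that recomputes all rows under that assumption (paper hyperparameters, FP16, and the table's apparent token counts of 1024 / 16000 / 128000); the names below are illustrative, not from the post.

```python
# Llama 3.1 hyperparameters from the paper: (layers, d_model, attention heads).
# Assumption: every size uses 8 KV heads, including 405B after the config fix.
MODELS = {"8B": (32, 4096, 32), "70B": (80, 8192, 64), "405B": (126, 16384, 128)}
NUM_KV_HEADS = 8
BYTES_FP16 = 2
# Token counts that reproduce the blog table's 8B and 70B columns.
TOKENS = {"1k": 1024, "16k": 16000, "128k": 128000}


def kv_cache_gib(n_tokens, num_layers, d_model, num_heads):
    """FP16 KV cache size in GiB (1024**3 bytes) for n_tokens tokens."""
    head_dim = d_model // num_heads
    return n_tokens * 2 * num_layers * NUM_KV_HEADS * head_dim * BYTES_FP16 / 1024**3


print("Model Size | " + " | ".join(f"{label} tokens" for label in TOKENS))
for name, (layers, d_model, heads) in MODELS.items():
    sizes = [kv_cache_gib(n, layers, d_model, heads) for n in TOKENS.values()]
    print(f"{name} | " + " | ".join(f"{size:.4f} GB" for size in sizes))
```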