Open lepangdan opened 2 weeks ago
Hi @lepangdan,
Thanks for your question!
First, I apologize for the error in the Introduction section of our paper. The sentence should read: "3 minutes to finish the pre-filling stage given a prompt of 300K tokens," not 6 minutes. You can also verify this in Figure 1(b). We will update the arXiv and NeurIPS versions ASAP. Thank you for pointing this out!
Regarding the TTFT of SDPA being three times faster than what we measured, I suspect the issue might be the absence of torch.cuda.synchronize(). You can follow the script provided in our repository (https://github.com/microsoft/MInference/blob/main/experiments/benchmarks/benchmark_e2e.py) and add the following argument at line #L125:
attn_implementation="sdpa", # default is flash_attention_2
Then run:
python experiments/benchmarks/benchmark_e2e.py --attn_type hf --context_window 50_000
The TTFT should be around 7.5 seconds. This result can also be cross-verified with the vLLM implementation:
Lastly, the original HF implementation does not support very large context windows. The optimization steps we performed are detailed in Appendix C.3. You can use --attn_type minference_with_dense with our optimized implementation, or leverage vLLM, to achieve longer context windows.
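To illustrate why a missing torch.cuda.synchronize() can make TTFT look much smaller than it is, here is a minimal sketch. The names measure_ttft and FakeAsyncDevice are hypothetical and not part of the MInference repository; FakeAsyncDevice only simulates asynchronous kernel launches with a background thread so the effect can be reproduced without a GPU.

```python
import threading
import time
from typing import Callable, Optional


def measure_ttft(run_prefill: Callable[[], object],
                 synchronize: Optional[Callable[[], None]] = None) -> float:
    """Return wall-clock seconds for one prefill call (hypothetical helper).

    `synchronize` must block until all queued device work has finished.
    """
    if synchronize is not None:
        synchronize()  # drain earlier work so it is not billed to this call
    start = time.perf_counter()
    run_prefill()
    if synchronize is not None:
        synchronize()  # wait for the work launched by run_prefill
    return time.perf_counter() - start


class FakeAsyncDevice:
    """Stand-in for a GPU stream (assumption for illustration only):
    launch() returns immediately while the work finishes in the background."""

    def __init__(self, work_seconds: float) -> None:
        self.work_seconds = work_seconds
        self._pending: list = []

    def launch(self) -> None:
        worker = threading.Thread(target=time.sleep, args=(self.work_seconds,))
        worker.start()
        self._pending.append(worker)

    def synchronize(self) -> None:
        for worker in self._pending:
            worker.join()
        self._pending.clear()


if __name__ == "__main__":
    device = FakeAsyncDevice(work_seconds=0.2)
    unsynced = measure_ttft(device.launch)  # only measures the launch
    device.synchronize()
    synced = measure_ttft(device.launch, device.synchronize)
    print(f"without synchronize: {unsynced:.3f}s, with synchronize: {synced:.3f}s")
```

On a CUDA machine, the same helper could be called with torch.cuda.synchronize as the synchronize argument and a function that runs the model's forward pass on the prompt as run_prefill.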
Thanks again for raising these points, and please let me know if you have further questions!
Hi @iofu728, thanks for your helpful reply.
a) I ran python experiments/benchmarks/benchmark_e2e.py --attn_type minference_with_dense --context_window 300_000, which resulted in 142s for 300,000 tokens, confirming that 3 minutes is close. torch.cuda.synchronize(). Thank you for your help!
b) When I ran python experiments/benchmarks/benchmark_e2e.py --attn_type hf --context_window 50_000, the result was 9.6 seconds, which seems fairly close to the 7.5 seconds you mentioned.
c) I noticed that when I run either python experiments/benchmarks/benchmark_e2e.py --attn_type hf --context_window 50_000 or python experiments/benchmarks/benchmark_e2e.py --attn_type minference_with_dense --context_window 50_000, the printed model shows that the default attention is LlamaSdpaAttention rather than flash_attention_2 (as shown below; I added print(model) at the end, on line 137 of https://github.com/microsoft/MInference/blob/7a3e5acaaf0e83105d941a4067f53020ca1eba12/experiments/benchmarks/benchmark_e2e.py). When I explicitly add the argument attn_implementation="flash_attention_2" on line 125, the printed model shows LlamaFlashAttention2 instead. According to the paper, the baseline should be flash_attention_2, right? But the default setting seems to be sdpa, so I'm a bit confused. Could you confirm whether I am misunderstanding something?
Default:
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=4096, out_features=128256, bias=False)
)
After adding the argument attn_implementation="flash_attention_2":
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaFlashAttention2(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=4096, out_features=128256, bias=False)
)
Looking forward to your reply.
Hi @lepangdan,
Thank you for your feedback. The results reported in the paper were obtained using minference_with_dense, as it supports longer contexts. In fact, you can also specify the use of flash_attn by setting attn_implementation="flash_attention_2".
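As a rough sketch of how the attention backend can be selected and checked without printing the whole model, something like the following may help. The helper names build_load_kwargs and load_and_report are hypothetical and not part of benchmark_e2e.py; the sketch assumes a Llama-style model loaded through Hugging Face transformers, and that flash_attention_2 additionally needs the flash-attn package and a half-precision dtype.

```python
from typing import Any, Dict

# Backend names accepted by the attn_implementation argument in transformers.
VALID_ATTN = ("eager", "sdpa", "flash_attention_2")


def build_load_kwargs(attn_implementation: str) -> Dict[str, Any]:
    """Build from_pretrained keyword arguments (hypothetical helper).

    Rejects misspelled backend names such as "sdpq" before any model is loaded.
    """
    if attn_implementation not in VALID_ATTN:
        raise ValueError(
            f"unknown attn_implementation {attn_implementation!r}; "
            f"expected one of {VALID_ATTN}"
        )
    return {"torch_dtype": "auto", "attn_implementation": attn_implementation}


def load_and_report(model_name: str, attn_implementation: str) -> str:
    """Load the model and return the class name of its first attention module."""
    # Imported lazily so the validation helper works without transformers installed.
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        model_name, **build_load_kwargs(attn_implementation)
    )
    # Assumes the Llama layout: model.model.layers[i].self_attn
    return type(model.model.layers[0].self_attn).__name__
```

With the transformers version used in this thread, load_and_report would be expected to return LlamaSdpaAttention or LlamaFlashAttention2, matching the two printouts above; other versions may name the attention class differently.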
Describe the issue
I came across your statement in the paper where you mentioned:
"When serving LLaMA-3-8B on a single A100 machine, the model would keep users waiting for 6 minutes to finish the pre-filling stage given a prompt of 300K tokens, and this number increases to 30 minutes for a prompt of 1M tokens."
However, I am also running on a single A100 (80GB) and using Hugging Face's implementation of LLaMA in SDPA mode. With a 50k token context, the pre-fill time is around 2.5 seconds, but when using 100k tokens, I run into an "Out of Memory" issue.
Could you clarify why there is such a significant discrepancy between your results and mine? Is there something I might be missing or misunderstanding?
Thanks for your help!