microsoft / MInference

[NeurIPS'24 Spotlight] MInference speeds up long-context LLM inference with approximate, dynamic sparse attention, reducing pre-filling latency by up to 10x on an A100 while maintaining accuracy.
https://aka.ms/MInference
MIT License

[Question]: Discrepancy in Pre-filling Time and Memory Consumption on Single A100 #84

Open · lepangdan opened this issue 2 weeks ago

lepangdan commented 2 weeks ago

Describe the issue

I came across your statement in the paper where you mentioned:

"When serving LLaMA-3-8B on a single A100 machine, the model would keep users waiting for 6 minutes to finish the pre-filling stage given a prompt of 300K tokens, and this number increases to 30 minutes for a prompt of 1M tokens."

However, I am also running on a single A100 (80GB) and using Hugging Face's implementation of LLaMA in SDPA mode. With a 50k token context, the pre-fill time is around 2.5 seconds, but when using 100k tokens, I run into an "Out of Memory" issue.

Could you clarify why there is such a significant discrepancy between your results and mine? Is there something I might be missing or misunderstanding?

Thanks for your help!

iofu728 commented 1 week ago

Hi @lepangdan,

Thanks for your question!

  1. First, I apologize for the error in the Introduction section of our paper. The sentence should read: "3 minutes to finish the pre-filling stage given a prompt of 300K tokens," not 6 minutes. You can also verify this in Figure 1(b). We will update the arXiv and NeurIPS versions ASAP. Thank you for pointing this out!

  2. Regarding the TTFT with SDPA appearing three times faster than what we measured: I suspect the cause is a missing torch.cuda.synchronize(). You can follow the script provided in our repository (https://github.com/microsoft/MInference/blob/main/experiments/benchmarks/benchmark_e2e.py) and add the following argument at line #L125:

    attn_implementation="sdpq", # default is flash_attention_2

    Then run:

    python experiments/benchmarks/benchmark_e2e.py --attn_type hf --context_window 50_000

    The TTFT should be around 7.5 seconds (a minimal timing sketch follows this list, showing where the synchronization matters). This result can also be cross-verified with the vLLM implementation:

    [Screenshot: vLLM cross-check of the ~7.5 s TTFT]

  3. Lastly, the original HF implementation does not support very long context windows; the optimization steps we applied are detailed in Appendix C.3. You can use --attn_type minference_with_dense with our optimized implementation, or leverage vLLM, to reach longer context windows.
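For reference, here is a minimal timing sketch (not the actual benchmark script) illustrating why the synchronization matters; the model name and the dummy prompt are placeholders:

import time
import torch
from transformers import AutoModelForCausalLM

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder; any LLaMA-3-8B checkpoint
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",    # or "flash_attention_2"
    device_map="cuda",
)

# Dummy 50K-token prompt (token ids only; the content does not matter for timing)
input_ids = torch.randint(0, 128_000, (1, 50_000), device="cuda")

torch.cuda.synchronize()           # make sure model loading / prior GPU work is done
start = time.time()
with torch.no_grad():
    model(input_ids)               # pre-filling forward pass
torch.cuda.synchronize()           # wait for the asynchronous CUDA kernels to finish
print(f"TTFT: {time.time() - start:.2f} s")

Without the second synchronize, the timer stops while the attention kernels are still running asynchronously on the GPU, which is how a ~7.5 s pre-fill can be reported as ~2.5 s.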

Thanks again for raising these points, and please let me know if you have further questions!

lepangdan commented 5 days ago

Hi @iofu728, thanks for your helpful reply.

  1. I ran python experiments/benchmarks/benchmark_e2e.py --attn_type minference_with_dense --context_window 300_000, which took 142 s for 300,000 tokens, so the 3-minute figure checks out.
  2. a) Yes, the issue was caused by the absence of torch.cuda.synchronize(). Thank you for your help!
     b) When I ran python experiments/benchmarks/benchmark_e2e.py --attn_type hf --context_window 50_000, the result was 9.6 seconds, which seems reasonably close to the 7.5 seconds you mentioned.
     c) I noticed that with both python experiments/benchmarks/benchmark_e2e.py --attn_type hf --context_window 50_000 and python experiments/benchmarks/benchmark_e2e.py --attn_type minference_with_dense --context_window 50_000, the printed model uses LlamaSdpaAttention (see the printouts below; I added print(model) at the end of the script, around line 137 of https://github.com/microsoft/MInference/blob/7a3e5acaaf0e83105d941a4067f53020ca1eba12/experiments/benchmarks/benchmark_e2e.py), rather than flash_attention_2. When I explicitly pass attn_implementation="flash_attention_2" on line 125, the printed model shows LlamaFlashAttention2 instead (a minimal sketch of this change follows the printouts). According to the paper, the baseline should be flash_attention_2, right? But the default setting seems to be sdpa, which confuses me. Could you confirm whether I am misunderstanding something?

Default:

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=4096, out_features=128256, bias=False)
)

After adding the argument attn_implementation="flash_attention_2":

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaFlashAttention2(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=4096, out_features=128256, bias=False)
)
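To make the comparison explicit, here is a minimal, hypothetical sketch of what my change on line 125 amounts to, plus a quick way to check which attention class was actually instantiated (the model name is a placeholder):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",     # placeholder model name
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",   # omitting this gives "sdpa" -> LlamaSdpaAttention
    device_map="cuda",
)

# Check which attention implementation was actually instantiated
print(type(model.model.layers[0].self_attn).__name__)   # e.g. LlamaFlashAttention2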

Looking forward to your reply.

iofu728 commented 4 days ago

Hi @lepangdan,

Thank you for your feedback. The results reported in the paper were obtained with minference_with_dense, since it supports longer contexts. You can also explicitly use FlashAttention by setting attn_implementation="flash_attention_2" when loading the model.
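In case it helps, here is a minimal sketch of applying the dense path outside the benchmark script, assuming the MInference(attn_type, model_name) patching pattern from the README carries over to minference_with_dense (treat the exact arguments as an assumption; the model name is a placeholder):

import torch
from transformers import AutoModelForCausalLM
from minference import MInference

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"   # placeholder model name
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="cuda",
)

# Patch the HF model with the optimized dense pre-filling path
# (assumed to mirror the benchmark script's --attn_type minference_with_dense).
minference_patch = MInference("minference_with_dense", model_name)
model = minference_patch(model)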