microsoft / MInference

To speed up long-context LLM inference, MInference computes attention with approximate, dynamic sparse methods, which reduces pre-filling latency by up to 10x on an A100 while maintaining accuracy.
https://aka.ms/MInference
MIT License

[Question]: Why is running MInference/examples/run_vllm.py not as fast as running vllm alone? #43

Open zjjznw123 opened 1 month ago

zjjznw123 commented 1 month ago

Describe the issue

from vllm import LLM, SamplingParams

from minference import MInference

prompts =  [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

prompts = prompts*100

sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    max_tokens=10,
)
model_name = "Qwen/Qwen2-7B-Instruct/"

llm = LLM(
    model_name,
    max_num_seqs=1,
    enforce_eager=True,
)

# Patch MInference Module
minference_patch = MInference("vllm", model_name)
llm = minference_patch(llm)

outputs = llm.generate(prompts, sampling_params)

import time
t1 = time.time()

# Iterate over the outputs (generation already completed in llm.generate above).
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text

t2 = time.time()

print('minference_time:',t2-t1)

print('=============================================================================')

from vllm import LLM, SamplingParams

from minference import MInference

prompts =  [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

prompts = prompts*100

sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    max_tokens=10,
)
model_name = "Qwen/Qwen2-7B-Instruct/"

llm = LLM(
    model_name,
    max_num_seqs=1,
    enforce_eager=True,
)

# Patch MInference Module (left disabled here to measure the plain vLLM baseline)
#minference_patch = MInference("vllm", model_name)
#llm = minference_patch(llm)

outputs = llm.generate(prompts, sampling_params)

import time
t1 = time.time()

# Iterate over the outputs (generation already completed in llm.generate above).
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text

t2 = time.time()

print('vllm_time:',t2-t1)

Result:

minference_time: 0.0003895759582519531 s
vllm_time: 0.0002791881561279297 s

Why is minference_time greater than vllm_time?
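A side note on the measurement itself: in both scripts above the timer starts only after llm.generate() has returned, so t2 - t1 covers just the loop over already-generated outputs, not generation. A minimal sketch that times the generation call directly, reusing the same vLLM setup as above, would look like this:

import time

from vllm import LLM, SamplingParams
from minference import MInference

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
] * 100

sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=10)
model_name = "Qwen/Qwen2-7B-Instruct"

llm = LLM(model_name, max_num_seqs=1, enforce_eager=True)

# Patch the model with MInference (skip these two lines for the vLLM-only run).
minference_patch = MInference("vllm", model_name)
llm = minference_patch(llm)

# Time the generation call itself rather than the post-processing loop.
t1 = time.time()
outputs = llm.generate(prompts, sampling_params)
t2 = time.time()
print("generation_time:", t2 - t1)

Even measured this way, prompts this short sit far below the break-even context length discussed in the reply below, so the patched run can still come out slower.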

iofu728 commented 1 month ago

Hi @zjjznw123, thanks for your support in MInference.

MInference involves some additional approximations that introduce overhead, so it is slower than dense attention at short context sizes. The latency breaks even at around 30k tokens, as shown in our end-to-end benchmark: End-to-End Benchmark. Latency can be further optimized by adjusting the sparsity rate of the sparse attention, although more aggressive sparsity at smaller context windows hasn't been extensively tested.
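To make that trade-off concrete, here is a toy single-head PyTorch sketch of the general idea (illustrative only; this is not MInference's actual index-estimation scheme or kernels): a cheap probe pass first estimates which key blocks matter, then exact attention is computed only over those blocks. The probe is a roughly fixed cost, which is why short prompts see no benefit, while skipping most key blocks is where the savings beyond the break-even point come from.

import torch
import torch.nn.functional as F

def toy_dynamic_sparse_attention(q, k, v, block_size=64, keep_blocks=4):
    # Single head, no batch, no causal mask -- just the shape of the idea.
    seq_len, head_dim = q.shape
    scale = head_dim ** -0.5

    # (1) Estimation overhead: score key blocks using only the last few queries.
    probe = ((q[-8:] @ k.T) * scale).softmax(dim=-1).mean(dim=0)        # [seq_len]
    pad = (-seq_len) % block_size
    block_scores = F.pad(probe, (0, pad)).view(-1, block_size).sum(dim=-1)
    top_blocks = block_scores.topk(min(keep_blocks, block_scores.numel())).indices

    # (2) Sparse attention: exact attention, but only over the selected key blocks,
    # skipping most of the O(seq_len^2) work that dense attention would do.
    keep = torch.cat([
        torch.arange(b * block_size, min((b + 1) * block_size, seq_len))
        for b in top_blocks.tolist()
    ])
    attn = ((q @ k[keep].T) * scale).softmax(dim=-1)
    return attn @ v[keep]

x = torch.randn(1024, 128)
print(toy_dynamic_sparse_attention(x, x, x).shape)   # torch.Size([1024, 128])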

We tested Qwen2-7B with HF at a 100k context window (below), and its speedup ratio matches the results with LLaMA-3-8B under triton==2.1.0.

➜  MInference git:(main) ✗ python3.9 experiments/benchmarks/benchmark_e2e.py --attn_type minference --context_window 100000 --model_name Qwen/Qwen2-7B-Instruct
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|████████████████████████████████████| 4/4 [00:02<00:00,  1.44it/s]
Patched model for minference..
100000 13.43383584022522
➜  MInference git:(main) ✗ python3.9 experiments/benchmarks/benchmark_e2e.py --attn_type minference_with_dense --context_window 100000 --model_name Qwen/Qwen2-7B-Instruct
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|████████████████████████████████████| 4/4 [00:02<00:00,  1.45it/s]
Patched model for minference..
100000 19.306786155700684
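In the log above, each run prints the context window followed by (presumably) the measured latency in seconds: about 13.4 s for minference vs 19.3 s for minference_with_dense at 100k tokens. For a rough idea of what such a prefill measurement looks like with HF Transformers, here is a simplified sketch; it is not the actual experiments/benchmarks/benchmark_e2e.py, and the "minference" attn_type string for the HF patch is assumed by analogy with the "vllm" patch used earlier in this thread.

import time

import torch
from transformers import AutoModelForCausalLM
from minference import MInference

model_name = "Qwen/Qwen2-7B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="cuda"
)

# Patch the HF model with MInference (assumed signature; comment out for the dense baseline).
minference_patch = MInference("minference", model_name)
model = minference_patch(model)

# Build a 100k-token dummy prompt and time a single forward pass (the prefill).
input_ids = torch.randint(0, model.config.vocab_size, (1, 100_000), device="cuda")
torch.cuda.synchronize()
t1 = time.time()
with torch.no_grad():
    model(input_ids)
torch.cuda.synchronize()
print("prefill_seconds:", time.time() - t1)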