
[Performance]: Sampler is too slow? #8040

Open niuzheng168 opened 1 month ago

niuzheng168 commented 1 month ago

Report of performance regression

Background: I am integrating my model into vLLM. The model is almost the same as Llama, but it has a multi-head lm_head, which amounts to something like a for loop in the sample function of LlamaForCausalLM:

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        # The sampler is invoked once per lm_head; in this simplified
        # illustration only the last head's tokens are returned.
        for i in range(self.num_head):
            next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens
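
A quick way to confirm that the extra time goes into these per-head sampler calls (rather than the model forward pass) is to time each invocation directly. A minimal instrumentation sketch, assuming the same hypothetical `self.num_head` and `self.sampler` attributes as in the snippet above and a CUDA device:

    import time
    import torch

    def sample(self, logits, sampling_metadata):
        for i in range(self.num_head):
            # Synchronize so pending GPU work is included in each measurement.
            torch.cuda.synchronize()
            t0 = time.perf_counter()
            next_tokens = self.sampler(logits, sampling_metadata)
            torch.cuda.synchronize()
            print(f"head {i}: sampler took {(time.perf_counter() - t0) * 1e3:.2f} ms")
        return next_tokens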

But when num_head goes up from 1 to 8, latency increases significantly. This can easily be reproduced with the test script below:

import torch
import time
from vllm import LLM, SamplingParams

torch.random.manual_seed(999)

# Local model path; adjust to wherever the weights live.
llm = LLM(model='/home/zhn/g/Meta-Llama-3-8B-Instruct', gpu_memory_utilization=0.5)
prompts = [
    "Hi my name is",
]

texts = []
start = time.time()
for i in range(10):
    # Greedy decoding: temperature=0, top_k=1, top_p=1.
    sampling_params = SamplingParams(temperature=0, top_k=1, max_tokens=200, top_p=1)
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        texts.append(generated_text)
end = time.time()
print(f"Time taken: {end - start:.2f}s")

num_head=1 [00:02<00:00, 2.48s/it, est. speed input: 1.61 toks/s, output: 80.63 toks/s] Time taken: 24.86s

num_head=8 Processed prompts: 100%| [00:03<00:00, 3.03s/it, est. speed input: 1.32 toks/s, output: 66.08 toks/s] Time taken: 30.51s

Almost a 25% perf regression (30.51 s vs. 24.86 s, about 23% slower). And since my actual model is much smaller than Llama 3 8B, the fixed sampler cost is a larger fraction of each step, so the regression is even more obvious there. Is this expected? Any ideas on how to mitigate or fix it?
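
To see where the extra time for num_head=8 actually goes, the generate call can also be wrapped in the standard torch.profiler. A rough sketch, reusing `llm`, `prompts`, and `SamplingParams` from the script above (no vLLM-specific profiling flags are assumed):

    from torch.profiler import ProfilerActivity, profile

    # Profile one generate call and report the most expensive CUDA ops.
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        llm.generate(prompts, SamplingParams(temperature=0, top_k=1, max_tokens=200, top_p=1))
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))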

youkaichao commented 1 month ago

Yes, the sampler part is indeed slow, and contributions to accelerate it are welcome!