Report of performance regression

Background: I am integrating my model into vLLM. The model is almost the same as Llama, except that it has a multi-head lm_head; the difference amounts to little more than a for loop in the `sample` function of `LlamaForCausalLM`:
```python
def sample(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
    # Simplified: the sampler runs once per lm_head; in the real model each
    # head has its own logits and the per-head results are combined.
    for i in range(num_head):
        next_tokens = self.sampler(logits, sampling_metadata)
    return next_tokens
```
But when num_head goes up from 1 to 8, the latency increases significantly.
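My guess (unverified) is that the extra time is per-call overhead from invoking the sampler once per head in a Python loop, rather than the sampling math itself. A standalone microbenchmark of that pattern, outside vLLM and with made-up shapes, shows how the looped form scales with num_head while a batched form does not:

```python
import time

import torch

# Standalone illustration (not vLLM code): the same small sampling-style op
# issued num_head times from a Python loop vs once as a batched op over a
# head dimension. Shapes are made up.
device = "cuda"
num_head, vocab = 8, 128256
logits = torch.randn(num_head, 1, vocab, device=device)

def timed(fn, iters=200):
    # Synchronize before and after so we measure the full GPU time.
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return time.time() - start

t_loop = timed(lambda: [logits[i].argmax(dim=-1) for i in range(num_head)])
t_batch = timed(lambda: logits.argmax(dim=-1))
print(f"looped: {t_loop:.3f}s, batched: {t_batch:.3f}s")
```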
The regression itself can easily be reproduced with the test script below:
```python
import torch
import time
from vllm import LLM, SamplingParams

torch.random.manual_seed(999)
llm = LLM(model='/home/zhn/g/Meta-Llama-3-8B-Instruct', gpu_memory_utilization=0.5)
prompts = [
    "Hi my name is",
]

texts = []
start = time.time()
for i in range(10):
    sampling_params = SamplingParams(temperature=0, top_k=1, max_tokens=200, top_p=1)
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        texts.append(generated_text)
end = time.time()
print(f"Time taken: {end - start:.2f}s")
```
num_head=1:
```
[00:02<00:00, 2.48s/it, est. speed input: 1.61 toks/s, output: 80.63 toks/s]
Time taken: 24.86s
```

num_head=8:
```
Processed prompts: 100%| [00:03<00:00, 3.03s/it, est. speed input: 1.32 toks/s, output: 66.08 toks/s]
Time taken: 30.51s
```
That is almost a 25% performance regression. And since my model is much smaller than Llama 3 8B, the fixed per-step sampling overhead is a larger fraction of the total time, so the regression is even more visible there. So, is this expected? Any ideas on how to mitigate or fix it?
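One mitigation that comes to mind (untested, and specific to this benchmark): with temperature=0 and top_k=1 the decoding is effectively greedy, so the per-head sampler loop could in principle collapse into a single batched argmax whose launch count does not scale with num_head. A sketch, assuming my model can stack its per-head logits into one tensor:

```python
import torch

def greedy_sample_all_heads(head_logits: torch.Tensor) -> torch.Tensor:
    """Greedy token selection for all heads in one kernel launch.

    head_logits: [num_head, num_seqs, vocab_size], the stacked per-head
    logits (this layout is an assumption about my model, not a vLLM API).
    Returns token ids with shape [num_head, num_seqs].
    """
    return head_logits.argmax(dim=-1)
```

Non-greedy settings would still need the full Sampler per head, so this is a workaround rather than a general fix.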
[X] Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.