[Open] xinyangz opened 5 months ago
I feel you :) Please see my comments inline.
Proposed Feature
Add an efficient interface for generation probabilities on fixed prompt and completion pairs. For example:
# ... load LLM or engine
prompt_completion_pairs = [
    ("1 + 1 = ", "2"),
    ("1 + 1 = ", "3"),
]
prompts, completions = list(zip(*prompt_completion_pairs))
probs = llm.completion_logprobs(prompts=prompts, completions=completions)
Alternatively, the interface could evaluate the probabilities of a fixed prompt with multiple generation options to better leverage prefix caching:
prompt = "1 + 1 = "
completions = ["2", "3", "4"]
probs = llm.completion_logprobs(prompt=prompt, completions=completions)
Currently, there are interfaces in the SamplingParams class to return the log probabilities of the prompt (prompt_logprobs) and of the generated tokens (logprobs). However, they are either inefficient or have incomplete support for this use case.

Motivation
The motivation for this feature comes from LLM evaluation on multiple-choice questions (e.g., MMLU). vLLM is a popular tool adopted by mainstream LLM evaluation frameworks (e.g., lm-evaluation-harness) for this purpose.
Consider the following example:

Question: Which of the following is true?
(A) ABC
(B) DEF
The answer is:
Evaluating a base LLM on this question involves calculating the probability of each choice, $P_{\text{LLM}}(\text{choice} \mid \text{question})$, and selecting the choice with the highest probability.
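For concreteness, this per-choice score is just the sum of per-token log probabilities under the standard autoregressive factorization (notation mine; $c_1, \dots, c_T$ are the tokens of the choice):

$\log P_{\text{LLM}}(\text{choice} \mid \text{question}) = \sum_{t=1}^{T} \log P_{\text{LLM}}(c_t \mid \text{question}, c_{<t})$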
Current solution
Currently, lm-evaluation-harness runs one generation per choice and evaluates the full prompt probabilities for this purpose:
question = "1 + 1 = "
choices = ["2", "3", "4"]
prompts = [question + c for c in choices]
sampling_params = SamplingParams(temperature=0, max_tokens=1, prompt_logprobs=1)
outputs = llm.generate(prompts=prompts, sampling_params=sampling_params)
Instead of evaluating probabilities on the choices alone, it evaluates them on question + choice and runs multiple generations, because of the limitations of vLLM's user interface.
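To make the workaround concrete, here is a rough sketch (mine, not lm-evaluation-harness code) of how the per-choice score can be recovered from the prompt_logprobs in the outputs above. It assumes the RequestOutput.prompt_logprobs layout of recent vLLM versions (a list with one entry per prompt token, the first entry None, the others dicts mapping token ids to Logprob objects) and glosses over tokenization boundary effects between question and choice:

tokenizer = llm.get_tokenizer()

def choice_logprob(output, question, choice):
    # Tokens that belong to the question (includes BOS if the tokenizer adds one).
    q_len = len(tokenizer(question).input_ids)
    full_ids = tokenizer(question + choice).input_ids
    # Sum the log probabilities of the choice tokens only; prompt_logprobs=1
    # still includes the actual prompt token in each per-position dict.
    return sum(
        output.prompt_logprobs[pos][full_ids[pos]].logprob
        for pos in range(q_len, len(full_ids))
    )

scores = [choice_logprob(out, question, c) for out, c in zip(outputs, choices)]
prediction = choices[scores.index(max(scores))]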
I think the new prefix caching feature will help, but it would require lm-eval to carefully reorder the prompts it sends to vLLM.
Efficiency issue with current solution
The issue with using prompt_logprobs is that it is very inefficient on long prompts. Let's use the following minimal profiling example (profiling.py):
import time

import numpy as np

from vllm import LLM, SamplingParams

n_seqs = 100
vocab_size = 10_000
seq_len = 4000
data = np.random.randint(0, vocab_size, (n_seqs, seq_len)).tolist()

llm = LLM("mistral-community/Mistral-7B-v0.2", max_model_len=8000, gpu_memory_utilization=0.6)
sampling_params = SamplingParams(temperature=0, max_tokens=1, prompt_logprobs=1)

start = time.perf_counter()
outputs = llm.generate(prompts=None, prompt_token_ids=data, sampling_params=sampling_params)
end = time.perf_counter()
print(f"Inference took {end - start:.4f} seconds")
Running the code with vLLM's official docker image:
docker run --gpus all --shm-size=10g --rm -e HF_TOKEN=[token] -v "$(pwd):/app" --entrypoint python3 vllm/vllm-openai:v0.4.3 /app/profiling.py
On a single A100-40G GPU, it takes around 500 seconds with prompt_logprobs=1 and only 27 seconds without prompt_logprobs. Moreover, we can fit much longer input prompts if we turn it off.
Can you sync to HEAD and try with SamplingParams.detokenize = False? You might be hit by https://github.com/vllm-project/vllm/issues/4904.
Analysis of the efficiency issue
A quick search for prompt_logprobs takes us to the Sampler.forward method in vllm/model_executor/layers/sampler.py. First, we noticed that the shape of the logits changes from (1, vocab_size) to (input_len, vocab_size) when prompt_logprobs is set. Second, we found that get_logprobs involves many Python for loops and CPU-GPU transfers.
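For comparison, here is an illustrative sketch (not vLLM code) of how per-token prompt logprobs could be gathered in a single vectorized GPU operation instead of a Python loop over positions:

import torch

def gather_prompt_logprobs(logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor:
    # logits: (input_len, vocab_size); token_ids: (input_len,) prompt token ids.
    logprobs = torch.log_softmax(logits.float(), dim=-1)
    # The logits at position t predict token t + 1, so align them with token_ids[1:].
    return logprobs[:-1].gather(1, token_ids[1:].unsqueeze(1)).squeeze(1)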
Chunked prefill can help with memory usage here. The feature is considered experimental for now.
Potential Changes
I see two ways to fix the efficiency issue with the current approach.
Option 1: Use prompt_logprobs but don't calculate it over the full prompt
We could reuse prompt_logprobs but limit the probability calculation to the final few tokens, so we don't have to pass around the large logits array.
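A hypothetical usage sketch of what this could look like; prompt_logprobs_last_n is an invented parameter, not an existing vLLM option, and question, choices, and llm are reused from the earlier example:

# Hypothetical: prompt_logprobs_last_n does not exist in vLLM today.
# The idea is to compute prompt logprobs only over the last N tokens (the choice),
# so the sampler never materializes the full (input_len, vocab_size) logprob tensor.
tokenizer = llm.get_tokenizer()
choice = choices[0]  # score one choice at a time
choice_ids = tokenizer(choice, add_special_tokens=False).input_ids
sampling_params = SamplingParams(
    temperature=0,
    max_tokens=1,
    prompt_logprobs=1,
    prompt_logprobs_last_n=len(choice_ids),  # hypothetical new field
)
outputs = llm.generate(prompts=[question + choice], sampling_params=sampling_params)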
Option 2: Use the sampling logprobs but constrain generation to the choices
Currently, there is a controlled generation interface for the OpenAI-compatible server, but not for offline inference.
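For offline inference, a rough approximation is possible today with SamplingParams.logits_processors (available in recent vLLM releases). A minimal sketch, assuming each choice is a single token and reusing question, choices, and llm from the earlier example:

import torch

tokenizer = llm.get_tokenizer()
# Assumes each choice maps to a single token; multi-token choices would need a stateful processor.
allowed_ids = [tokenizer(c, add_special_tokens=False).input_ids[0] for c in choices]

def restrict_to_choices(generated_token_ids, logits):
    # Called per sequence with the 1-D next-token logits; mask everything except the choices.
    mask = torch.full_like(logits, float("-inf"))
    mask[allowed_ids] = 0.0
    return logits + mask

sampling_params = SamplingParams(
    temperature=0,
    max_tokens=1,
    logprobs=len(allowed_ids),
    logits_processors=[restrict_to_choices],
)
outputs = llm.generate(prompts=[question], sampling_params=sampling_params)

The logprobs reported for the generated token should then be computed over the masked logits, i.e., effectively renormalized over the allowed choices.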
Alternatives
Another possible option is to use the generated tokens directly for evaluation instead of relying on logprobs/perplexity. This is the approach used by https://github.com/openai/simple-evals and fits vLLM's design better.