vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Feature]: Add efficient interface for evaluating probabilities of fixed prompt-completion pairs #5234

Open xinyangz opened 1 month ago

xinyangz commented 1 month ago

Proposed Feature

Add an efficient interface for evaluating generation probabilities of fixed prompt-completion pairs. For example:

# ... load LLM or engine
prompt_completion_pairs = [
    ("1 + 1 = ", "2"),
    ("1 + 1 = ", "3"),
]
prompts, completions = list(zip(*prompt_completion_pairs))
probs = llm.completion_logprobs(prompts=prompts, completions=completions)

Alternatively, the interface could evaluate the probabilities of multiple completion options for a single fixed prompt, to better leverage prefix caching:

prompt = "1 + 1 = "
completions = ["2", "3", "4"]
probs = llm.completion_logprobs(prompt=prompt, completions=completions)

Currently, there are interfaces in class SamplingParams to return the log probabilities of the prompt (prompt_logprobs) and of the generated tokens (logprobs). However, they are either inefficient or have incomplete support for this use case.

Motivation

The motivation for this feature comes from LLM evaluation on multiple-choice questions (e.g., MMLU). vLLM is a popular backend adopted by mainstream LLM evaluation frameworks (e.g., lm-evaluation-harness) for this purpose.

Using the following example:

Question: Which of the following is true? (A) ABC (B) DEF The answer is:

Evaluating a base LLM on this question involves calculating the probability on each choice $P_{\text{LLM}}(\text{choice} \mid \text{question})$ and selecting the choice with the highest probability.
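Written out with the chain rule (a restatement of the above, where $c_1, \dots, c_T$ are the tokens of a choice), each choice is scored by summing per-token log-probabilities, and the highest-scoring choice is picked:

$$\log P_{\text{LLM}}(\text{choice} \mid \text{question}) = \sum_{t=1}^{T} \log P_{\text{LLM}}(c_t \mid \text{question}, c_{<t}), \qquad \hat{c} = \arg\max_{\text{choice}} \log P_{\text{LLM}}(\text{choice} \mid \text{question})$$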

Current solution

Currently, lm-evaluation-harness runs one generation per choice and evaluates the full prompt probabilities for this purpose.

question = "1 + 1 = "
choices = ["2", "3", "4"]
prompts = [question + c for c in choices]
sampling_params = SamplingParams(temperature=0, max_tokens=1, prompt_logprobs=1)
outputs = llm.generate(prompts=prompts, sampling_params=sampling_params)

Instead of evaluating probabilities only on the choice tokens, it evaluates them over the full question + choice strings and runs a separate generation for each choice, because of limitations in vLLM's user interface.
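For reference, here is a minimal sketch of how per-choice scores can then be pulled out of that output. It assumes vLLM ~0.4.x result objects (prompt_token_ids and prompt_logprobs with Logprob entries) and that the tokenization of question is a prefix of the tokenization of question + choice, a boundary that lm-evaluation-harness handles more carefully:

# Continuing the snippet above: score each choice by summing the logprobs of
# its tokens, taken from the prompt_logprobs of the corresponding output.
tokenizer = llm.get_tokenizer()
n_question_tokens = len(tokenizer(question).input_ids)

scores = []
for out in outputs:
    choice_token_ids = out.prompt_token_ids[n_question_tokens:]  # tokens of the appended choice
    choice_logprobs = out.prompt_logprobs[n_question_tokens:]    # one {token_id: Logprob} dict per position
    scores.append(sum(d[tid].logprob for tid, d in zip(choice_token_ids, choice_logprobs)))

best_choice = choices[scores.index(max(scores))]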

Efficiency issue with current solution

The issue with using prompt_logprobs is that it is very inefficient on long prompts.

Let's use the following minimal profiling example (profiling.py):

import time

import numpy as np
from vllm import LLM, SamplingParams

n_seqs = 100
vocab_size = 10_000
seq_len = 4000
data = np.random.randint(0, vocab_size, (n_seqs, seq_len)).tolist()

llm = LLM("mistral-community/Mistral-7B-v0.2", max_model_len=8000, gpu_memory_utilization=0.6)
sampling_params = SamplingParams(temperature=0, max_tokens=1, prompt_logprobs=1)

start = time.perf_counter()
outputs = llm.generate(prompts=None, prompt_token_ids=data, sampling_params=sampling_params)
end = time.perf_counter()

print(f"Inference took {end - start:.4f} seconds")

Running the code with vLLM's official Docker image:

docker run --gpus all --shm-size=10g --rm -e HF_TOKEN=[token] -v "$(pwd):/app" --entrypoint python3 vllm/vllm-openai:v0.4.3 /app/profiling.py

On a single A100-40G GPU, it runs in around 500 seconds with prompt_logprobs=1 and only 27 seconds with no prompt_logprobs. Moreover, we can fit much longer input prompts if we turn it off.

Analysis of the efficiency issue

A quick search for prompt_logprobs leads to the Sampler.forward method in vllm/model_executor/layers/sampler.py.

First, we noticed that the shape of the logits changes from (1, vocab_size) to (input_len, vocab_size) when prompt_logprobs is set. Second, we found that get_logprobs involves many Python for loops and CPU-GPU transfers.
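A back-of-the-envelope estimate of why this hurts, assuming Mistral-7B-v0.2's ~32k vocabulary and 4-byte floats for the logprob computation (both are assumptions, not measurements):

# With prompt_logprobs set, logits/logprobs are materialized for every prompt
# position instead of only the last one -- roughly 4000x more data per sequence
# for the 4000-token prompts in the profiling script.
seq_len, vocab_size, bytes_per_val = 4000, 32_000, 4
print(f"full prompt:     {seq_len * vocab_size * bytes_per_val / 2**20:.0f} MiB per sequence")  # ~488 MiB
print(f"last token only: {vocab_size * bytes_per_val / 2**10:.0f} KiB per sequence")            # ~125 KiB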

Potential Changes

I see two ways to fix the efficiency issue with the current approach.

Option 1: Use prompt_logprobs but don't compute it over the full prompt

We could reuse prompt_logprobs but limit the probability calculation to the final few tokens, so we don't have to pass around the large logits array.
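As a sketch of what this could look like from the user side (the prompt_logprobs_last_n parameter below is purely hypothetical and does not exist in vLLM today):

from vllm import SamplingParams

# Hypothetical API sketch only: compute prompt logprobs just for the trailing
# tokens that correspond to the completion, so the sampler never has to build
# an (input_len, vocab_size) logits tensor.
sampling_params = SamplingParams(
    temperature=0,
    max_tokens=1,
    prompt_logprobs=1,
    prompt_logprobs_last_n=4,  # hypothetical: only score the last 4 prompt tokens
)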

Option 2: Use the sampling logprobs but constrain the generation to the choices

Currently, there is a controlled (guided) generation interface for the OpenAI-compatible server, but not for offline inference.
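Until something like that exists for offline inference, a rough stand-in for the common single-token-choice case is to request sampling logprobs and read off the choice tokens at the first generated position. This is only a sketch against the vLLM ~0.4.x output objects: it is not true constrained generation, it only works if the choice tokens land in the requested top-k, and the naive choice tokenization below glosses over leading-space effects.

from vllm import LLM, SamplingParams

llm = LLM("mistral-community/Mistral-7B-v0.2")
tokenizer = llm.get_tokenizer()

prompt = "1 + 1 = "
choices = ["2", "3", "4"]
# Naive: take the last token id of each choice; real code must handle the
# tokenizer's leading-space / merge behaviour at the prompt boundary.
choice_ids = [tokenizer(c, add_special_tokens=False).input_ids[-1] for c in choices]

params = SamplingParams(temperature=0, max_tokens=1, logprobs=20)
out = llm.generate(prompts=[prompt], sampling_params=params)[0]

# logprobs[0] maps token_id -> Logprob for the first generated position.
first_step = out.outputs[0].logprobs[0]
scores = [first_step[i].logprob if i in first_step else float("-inf") for i in choice_ids]
print(choices[scores.index(max(scores))])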

Alternatives

No response

Additional context

No response

zifeitong commented 1 month ago

I feel you :) Please see my comments inline.

Current solution

I think the new prefix caching feature will help, but it would require lm-eval to carefully reorder the prompts sent to vLLM.
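For example, something along these lines (a sketch; enable_prefix_caching is the automatic prefix caching flag, and the per-choice prompts share an identical question prefix and sit next to each other in the batch):

from vllm import LLM, SamplingParams

# Sketch: turn on automatic prefix caching and keep prompts that share the
# question prefix adjacent, so cached KV blocks for the prefix get reused.
llm = LLM("mistral-community/Mistral-7B-v0.2", enable_prefix_caching=True)

question = "1 + 1 = "
choices = ["2", "3", "4"]
prompts = [question + c for c in choices]  # identical prefix, adjacent in the batch
outputs = llm.generate(
    prompts=prompts,
    sampling_params=SamplingParams(temperature=0, max_tokens=1, prompt_logprobs=1),
)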

Efficiency issue with current solution

Can you sync to HEAD and try with SamplingParams.detokenize = False? You might be hitting https://github.com/vllm-project/vllm/issues/4904.
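For the profiling script above, that would mean only changing the sampling params (sketch):

from vllm import SamplingParams

# Same profiling run, but skip detokenizing the returned prompt logprobs.
sampling_params = SamplingParams(temperature=0, max_tokens=1, prompt_logprobs=1, detokenize=False)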

Analysis of the efficiency issue

Chunked prefill can help with memory usage here. The feature is considered experimental for now.
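For example (a sketch; enable_chunked_prefill is the experimental engine flag referred to above):

from vllm import LLM

# Sketch: enable chunked prefill so long prompts are processed in smaller
# chunks, bounding the per-step activation and logits size.
llm = LLM(
    "mistral-community/Mistral-7B-v0.2",
    max_model_len=8000,
    gpu_memory_utilization=0.6,
    enable_chunked_prefill=True,
)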

Alternatives

Another possible option is, instead of relying on logprobs/perplexity, to use the generated tokens directly for evaluation. That's the approach used by https://github.com/openai/simple-evals, and it fits vLLM's design better.
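A minimal sketch of that style of evaluation (the prompt wording and answer parsing here are illustrative only):

from vllm import LLM, SamplingParams

llm = LLM("mistral-community/Mistral-7B-v0.2")
prompt = (
    "Question: Which of the following is true? (A) ABC (B) DEF\n"
    "Answer with a single letter. The answer is:"
)
out = llm.generate(
    prompts=[prompt],
    sampling_params=SamplingParams(temperature=0, max_tokens=2),
)[0]
predicted = out.outputs[0].text.strip().lstrip("(")[:1]  # e.g. "A" or "B"
print(predicted)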
