vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

How to use vllm to compute ppl score for input text? #1019

Open yinochaos opened 1 year ago

yinochaos commented 1 year ago

How can I use vLLM to compute a perplexity (PPL) score for an input text? I want to use vLLM to speed up PPL computation.

RanchiZhao commented 1 year ago

same q

renyiyu commented 9 months ago

+1

yuanzhiyong1999 commented 1 month ago

Have you solved it?

ThomasAtlantis commented 2 weeks ago

The stock vLLM framework does not support this directly, because it is a serving engine for downstream applications. However, we can capture the output logits of the prefill tokens by registering a forward hook on the logits processor.

import torch

from vllm.model_executor.layers.logits_processor import _apply_logits_processors

logits_list = []

def forward_hook(module, input, output):
    # Recompute the logits from the hook's inputs (mirroring LogitsProcessor.forward)
    # and stash them so the prefill logits can be read back after generation.
    lm_head, hidden_states, sampling_metadata, *embedding_bias = input
    embedding_bias = embedding_bias[0] if embedding_bias else None
    logits = module._get_logits(hidden_states, lm_head, embedding_bias)
    if logits is not None:
        if module.soft_cap is not None:
            logits = logits / module.soft_cap
            logits = torch.tanh(logits)
            logits = logits * module.soft_cap
        if module.scale != 1.0:
            logits *= module.scale
        logits = _apply_logits_processors(logits, sampling_metadata)
        logits_list.append(logits)
    return output

Suppose we're using the OPT model. We can then attach the hook to the llm instance:

from vllm import LLM, SamplingParams

llm = LLM(model="weights/opt-125m")
llm.llm_engine.model_executor.driver_worker.model_runner.model.logits_processor.register_forward_hook(forward_hook)
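A side note not in the original snippet: register_forward_hook returns a handle, so you can keep it to detach the hook once you are done, and clear the buffer between runs so logits from different calls don't mix:

handle = llm.llm_engine.model_executor.driver_worker.model_runner.model.logits_processor.register_forward_hook(forward_hook)
# ... call llm.generate(...) and read logits_list ...
logits_list.clear()  # reset the buffer before the next batch
handle.remove()      # detach the hook when finished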

Next, we run generation and compute the perplexity on the GPU with torcheval.metrics.text.Perplexity:

from torcheval.metrics.text import Perplexity

# sampling_params is not defined in the original comment; any SamplingParams works
# here since we only need the prefill logits, so max_tokens=1 keeps generation minimal.
sampling_params = SamplingParams(max_tokens=1)

token_ids = llm.get_tokenizer().encode("Hello world!")
llm.generate(prompt_token_ids=token_ids, sampling_params=sampling_params)

metrics = Perplexity(device="cuda:0")
logits = logits_list[0].unsqueeze(0)                      # (1, seq_len, vocab_size)
labels = torch.LongTensor(token_ids).unsqueeze(0).cuda()  # (1, seq_len)
metrics.update(logits[:, :-1], labels[:, 1:])             # logits at position t score token t+1
print(metrics.compute())

Please refer to LogitsProcessor.forward and register_forward_hook for more information.
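For a quick sanity check (this sketch is not from the original answer and assumes the snippets above have already been run), the same perplexity can be computed directly from the captured prefill logits with torch.nn.functional.cross_entropy:

# Average negative log-likelihood of token t+1 given the logits at position t,
# exponentiated to get perplexity; this should match the torcheval value above.
logits = logits_list[0]                                  # (seq_len, vocab_size)
labels = torch.LongTensor(token_ids).to(logits.device)   # (seq_len,)
nll = torch.nn.functional.cross_entropy(logits[:-1], labels[1:])
print(torch.exp(nll).item())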

hjc3613 commented 1 day ago

(quotes ThomasAtlantis's reply above in full)

What a surprising method! I found that when I set max_tokens=2 and pass multiple prompts, e.g. [sent1, sent2, sent3], then after llm.generate() completes, logits_list[0].shape[0] == len(sent1 tokens) + len(sent2 tokens) + len(sent3 tokens), while logits_list[1].shape[0] == 4 and logits_list[2].shape[0] == 2. I guess logits_list[0] holds the prefill logits and logits_list[1] the first generated token, and so on. How does vLLM handle multiple prompts? I can't quite follow it; I'm new to vLLM and it seems quite complex.
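If that reading is right (the first captured tensor stacks the prefill logits of all prompts in submission order, and later tensors hold one row per sequence still decoding), a minimal sketch for recovering per-prompt perplexities from a batched call could look like the following; the example prompts and the split-by-prompt-length logic are assumptions based on the shapes observed above, not behavior documented by vLLM:

# Assumption: logits_list[0] concatenates the prefill logits of all prompts
# in the order they were submitted, so prompt lengths can be used to split it.
tokenizer = llm.get_tokenizer()
prompts = ["sent1", "sent2", "sent3"]                    # placeholder prompts
batch_token_ids = [tokenizer.encode(p) for p in prompts]

logits_list.clear()
llm.generate(prompt_token_ids=batch_token_ids, sampling_params=sampling_params)

prefill_logits = logits_list[0]                          # (sum of prompt lengths, vocab_size)
offset = 0
for ids in batch_token_ids:
    chunk = prefill_logits[offset:offset + len(ids)]     # logits for this prompt
    labels = torch.LongTensor(ids).to(chunk.device)
    nll = torch.nn.functional.cross_entropy(chunk[:-1], labels[1:])
    print(torch.exp(nll).item())                         # per-prompt perplexity
    offset += len(ids)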