vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

Compute perplexity/logits for the prompt #2364

Open dsmilkov opened 11 months ago

dsmilkov commented 11 months ago

I'd like to use Phi-2 to compute the perplexity of prompts over an entire dataset. Is there an API for this? In the short term, I'm happy to fork https://github.com/vllm-project/vllm/blob/d0215a58e78572d91dadafe9d832a2db89b09a13/vllm/model_executor/models/phi_1_5.py if you can provide a pointer on how to do that.

Also happy to later contribute back an API that works for all causal models.

mayiran1999 commented 10 months ago

I have the same need. Has anyone found a way to get the logits of the prompt?

caiyuhu commented 10 months ago

I have the same need, but unfortunately it appears that vLLM has not yet implemented support for this, as discussed in the following issue: https://github.com/vllm-project/vllm/issues/185

lekhang4497 commented 9 months ago

I think you can use the parameter prompt_logprobs in SamplingParams for this purpose.

#1328
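For anyone looking for a concrete starting point, here is a minimal sketch of that suggestion (the model name and prompt are placeholders, and the exact return type varies by vLLM version):

from vllm import LLM, SamplingParams

llm = LLM(model="microsoft/phi-2")  # placeholder model
params = SamplingParams(max_tokens=1, prompt_logprobs=1)

outputs = llm.generate(["The quick brown fox jumps over the lazy dog"], params)

# prompt_logprobs is a per-position list aligned with the prompt tokens; the
# first entry is None because the first token has no preceding context. Each
# other entry maps token id -> logprob info (a Logprob object in recent vLLM,
# a plain float in older versions).
for pos, entry in enumerate(outputs[0].prompt_logprobs):
    if entry is None:
        continue
    for token_id, lp in entry.items():
        print(pos, token_id, lp)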

dylanbowman314 commented 5 months ago

prompt_logprobs can only return the probabilities for the top <=20 tokens right now, so it's not applicable for this use case.

junzhang-zj commented 3 months ago

Is there any progress on this issue at the moment?

Tendo33 commented 2 months ago

Same issue here.

CodeAsPoetry commented 2 months ago

You can set logprobs=1, prompt_logprobs=1. Then: (Screenshot 2024-09-18 180643)

CodeAsPoetry commented 2 months ago

I tested a prompt longer than 20 tokens and it seems to be OK. (Screenshot 2024-09-18 181700)
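On the top-20 concern: as far as I can tell, each per-position dictionary also includes the actual prompt token even when it is not among the top k, which is what makes scoring the prompt possible. A small check, with a placeholder model name:

from vllm import LLM, SamplingParams

llm = LLM(model="microsoft/phi-2")  # placeholder model
params = SamplingParams(max_tokens=1, prompt_logprobs=1)

prompt = "A fairly unusual sentence with some rarer words, e.g. sesquipedalian."
out = llm.generate([prompt], params)[0]

# prompt_logprobs[0] is None (no context for the first token); each later entry
# is a dict keyed by token id. Check that the prompt's own token is present.
for tid, entry in zip(out.prompt_token_ids[1:], out.prompt_logprobs[1:]):
    assert tid in entry, "prompt token missing from its per-position logprobs"
print("every prompt token has its own logprob recorded")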

Rachum-thu commented 1 day ago

> prompt_logprobs can only return the probabilities for the top <=20 tokens right now, so it's not applicable for this use case.

Rachum-thu commented 1 day ago

Try this code; it may help you solve the problem:

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

# Model path is a placeholder; any causal LM served by vLLM should work.
model_name = "meta-llama/Llama-2-7b-hf"
llama = LLM(model=model_name)
llama_tokenizer = AutoTokenizer.from_pretrained(model_name)

prefix_list = ['my name is', 'I love']
candidate_list = [[' Hongliang', ' Raymond', ' John'], [' ice cream', ' pizza', ' coding']]

# Initialize sampling parameters; prompt_logprobs=20 requests the top-20 logprobs
# for every prompt position (the actual prompt token is included as well).
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=8, prompt_logprobs=20)

# Process each prefix and corresponding candidates
for prefix, candidates in zip(prefix_list, candidate_list):
    results = {}
    prefix_tokens = llama_tokenizer(prefix)['input_ids']
    prefix_token_length = len(prefix_tokens)

    # Build the full prompts (prefix + candidate) and tokenize them
    prompts = [prefix + candidate for candidate in candidates]
    prompt_tokens = llama_tokenizer(prompts)
    suffix_tokens_length = [len(token) - prefix_token_length for token in prompt_tokens['input_ids']]

    # Generate outputs with prompt logprobs attached
    outputs = llama.generate(prompts, sampling_params)

    # Collect the log probability of each candidate's tokens
    for i, (candidate, output, suffix_len) in enumerate(zip(candidates, outputs, suffix_tokens_length)):
        logprobs = output.prompt_logprobs[-suffix_len:]
        target_tokens = prompt_tokens['input_ids'][i][-suffix_len:]

        # Extract log probabilities for the target (candidate) tokens.
        # Recent vLLM wraps each value in a Logprob object (use .logprob);
        # older versions returned plain floats.
        log_probs = []
        for j in range(suffix_len):
            lp = logprobs[j][target_tokens[j]]
            log_probs.append(lp.logprob if hasattr(lp, 'logprob') else lp)
        results[candidate] = log_probs
    print(results)
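Circling back to the original ask (perplexity of the prompt), here is a hedged sketch built on the same prompt_logprobs output; the model name is a placeholder, and it assumes the Logprob-object return type of recent vLLM with a fallback for older float returns:

import math
from vllm import LLM, SamplingParams

llm = LLM(model="microsoft/phi-2")  # placeholder; the original ask was Phi-2
params = SamplingParams(max_tokens=1, prompt_logprobs=1)

def prompt_perplexity(output):
    """Perplexity over the prompt tokens (first token skipped: vLLM reports None there)."""
    total, count = 0.0, 0
    for tid, entry in zip(output.prompt_token_ids[1:], output.prompt_logprobs[1:]):
        lp = entry[tid]
        total += lp.logprob if hasattr(lp, 'logprob') else lp
        count += 1
    return math.exp(-total / count)

for out in llm.generate(["My name is John.", "I love pizza."], params):
    print(out.prompt, prompt_perplexity(out))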