vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Misc]: hidden states using vllm #3594

Open ra-MANUJ-an opened 7 months ago

ra-MANUJ-an commented 7 months ago


The following is a small piece of code that extracts embeddings from specific layers of an LLM:

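import torch

def process_row(prompt: str, model, tokenizer, layers_to_use: list, remove_period: bool):
    """
    Processes a row of data and returns the last-token embedding for each requested layer.
    """
    if remove_period:
        prompt = prompt.rstrip(". ")
    # Move the inputs to the model's device so this also works on GPU.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            output_hidden_states=True,
            return_dict_in_generate=True,
            max_new_tokens=1,
            min_new_tokens=1,
        )
    embeddings = {}
    for layer in layers_to_use:
        # hidden_states[0] is the prompt step; index [layer][batch 0][last token]
        last_token_state = outputs.hidden_states[0][layer][0][-1]
        # .float().cpu() so .numpy() also works for bf16/fp16 tensors on GPU
        embeddings[layer] = [last_token_state.float().cpu().numpy().tolist()]
    return embeddings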

This is a fairly standard approach, but it is slow. Is there a way to use vLLM to speed it up without calling the generate function every time? I've tried batching, but that is slow too. Any help is appreciated!
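Batching only helps if each row's last real token is picked correctly. With a plain forward pass (for example, calling a transformers model with output_hidden_states=True, which also skips the generation loop since no new token is needed), the [-1] index from the code above is only correct when all prompts have the same length or the tokenizer pads on the left. Below is a minimal, framework-free sketch of selecting each row's last non-padding token from a right-padded batch. It assumes one layer's hidden states as nested lists shaped [batch][seq][dim] and a 0/1 attention mask; last_token_states is a hypothetical helper name, and with torch tensors the same indexing would be done on the tensor directly.

```python
from typing import List

Vector = List[float]


def last_token_states(
    hidden: List[List[Vector]], attention_mask: List[List[int]]
) -> List[Vector]:
    """Return the hidden state of the last non-padding token in each row.

    hidden:         [batch][seq][dim] hidden states from one layer
    attention_mask: [batch][seq], 1 for real tokens and 0 for padding
    Assumes right padding, i.e. real tokens come first in every row.
    """
    result = []
    for row_states, row_mask in zip(hidden, attention_mask):
        length = sum(row_mask)  # number of real tokens in this row
        if length == 0:
            raise ValueError("row contains only padding")
        result.append(row_states[length - 1])
    return result


# Usage: two prompts of length 3 and 2, right-padded to length 3, dim = 2.
hidden = [
    [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
    [[1.0, 1.1], [1.2, 1.3], [9.9, 9.9]],  # last position is padding
]
mask = [[1, 1, 1], [1, 1, 0]]
print(last_token_states(hidden, mask))  # [[0.5, 0.6], [1.2, 1.3]]
```

If the tokenizer is set to pad on the left instead (tokenizer.padding_side = "left"), every row ends with a real token and the original [-1] indexing works unchanged.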

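One way to get the last hidden state with vLLM is to call the worker's model directly, as follows (llm_engine.workers, _prepare_inputs and the model's forward signature are internals, not public API, and they differ between vLLM releases):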

from vllm import LLM, SamplingParams
from vllm.sequence import SequenceData, SequenceGroupMetadata

llm = LLM(model=path_to_llama2)  # path_to_llama2: local path or Hugging Face ID (defined elsewhere)

# SamplingParams are required by SequenceGroupMetadata; they do not affect the
# hidden states, because the sampler is never run below.
vocab_size = llm.llm_engine.workers[0].model.config.vocab_size
sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
prompt = train[0]  # train: list of prompt strings (defined elsewhere)
prompt_token_ids = llm.llm_engine.tokenizer.encode(prompt)  # e.g. [2, 100, 524, 10]
seqs = []

group_id = 1
seq_data = SequenceData(prompt_token_ids)
seq = SequenceGroupMetadata(
    request_id=str(group_id),
    is_prompt=True,
    seq_data={group_id: seq_data},
    sampling_params=sampling_params,
    block_tables=None,
)
seqs.append(seq)
input_tokens, input_positions, input_metadata = llm.llm_engine.workers[0]._prepare_inputs(
    seqs)
prompt_len = len(seq_data.prompt_token_ids)
input_tokens = input_tokens[:prompt_len]
input_positions = input_positions[:prompt_len]
# Execute the model.
num_layers = llm.llm_engine.workers[0].model_config.get_num_layers(llm.llm_engine.workers[0].parallel_config)
tempOut = llm.llm_engine.workers[0].model.model(
    input_ids=input_tokens,
    positions=input_positions,
    kv_caches=[(None, None)] * num_layers,
    input_metadata=input_metadata,
    cache_events=None,
)
print(tempOut.size())  # hidden states of the final layer only (after the model's final norm)

However, this only returns the hidden states of the final layer, not those of every layer. Is there another, faster way to get all of them?
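Since model.model(...) returns only the last layer's output, one common way to collect every layer is to intercept each decoder layer's output while the forward pass runs. In PyTorch this is usually done with nn.Module.register_forward_hook, whose hook is called as hook(module, input, output). The framework-free sketch below shows the same idea by wrapping each layer of a hypothetical ToyModel; ToyModel and capture_all_layers are illustrative names only and are not part of vLLM.

```python
from typing import Callable, List

Layer = Callable[[List[float]], List[float]]


class ToyModel:
    """Hypothetical stand-in for a decoder stack: applies its layers in order."""

    def __init__(self, layers: List[Layer]):
        self.layers = layers

    def forward(self, x: List[float]) -> List[float]:
        for layer in self.layers:
            x = layer(x)
        return x  # like model.model(...), only the final output comes back


def capture_all_layers(model: ToyModel) -> List[List[float]]:
    """Wrap every layer so its output is recorded; returns the shared record list.

    The list keeps growing on every forward call, so clear it between prompts.
    """
    captured: List[List[float]] = []

    def make_recorder(layer: Layer) -> Layer:
        def recorder(x: List[float]) -> List[float]:
            out = layer(x)
            captured.append(out)  # same role as a forward hook's `output` argument
            return out
        return recorder

    model.layers = [make_recorder(layer) for layer in model.layers]
    return captured


model = ToyModel([
    lambda v: [a + 1 for a in v],
    lambda v: [a * 2 for a in v],
    lambda v: [a - 3 for a in v],
])
captured = capture_all_layers(model)
final = model.forward([1.0, 2.0])
print(final)     # [1.0, 3.0]
print(captured)  # [[2.0, 3.0], [4.0, 6.0], [1.0, 3.0]]
```

Applied to the snippet above, the analogue would be registering a forward hook on each entry of the decoder layer list (presumably llm.llm_engine.workers[0].model.model.layers, an assumption based on the model.model call) before running the forward pass. In some vLLM versions a Llama decoder layer returns a (hidden_states, residual) tuple rather than a single tensor, so check what your version returns before storing it.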

github-actions[bot] commented 2 weeks ago

This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!