run-llama / llama_index

LlamaIndex is a data framework for your LLM applications
https://docs.llamaindex.ai
MIT License

[Question]: How to use open source model served with vLLM with Query Pipeline in llama index? #15972

Open RishabhSingh021 opened 1 month ago

RishabhSingh021 commented 1 month ago

Question Validation

Question

So I tried running a model like this, while serving the model with vLLM serve:

llm = VllmServer(
    api_url="http://localhost:8000/v1/completions",
    model="TheBloke/zephyr-7B-beta-AWQ",
    max_new_tokens=256,
    temperature=0.1,
    messages_to_prompt=messages_to_prompt,
    completion_to_prompt=completion_to_prompt,
)

When I try to run the query pipeline with a summarizer that uses this model, like this:

retriever = hybrid_index.as_retriever(similarity_top_k=6)
summarizer = TreeSummarize(llm=llm)

pipeline = QueryPipeline(
    modules={
        "input": InputComponent(),
        "retriever": retriever,
        "summarizer": summarizer,
    },
    verbose=True
)

pipeline.add_link("input", "retriever")
pipeline.add_link("input", "summarizer", dest_key="query_str")
pipeline.add_link("retriever", "summarizer", dest_key="nodes")

user_inputs = [
    "What is Polarin DCI Wave?",
    "Is PAYG available for Ports?",
    "Is PAYG available for Virtual Router?",
    "Is virtual appliance available with PAYG plan?"
]

for msg in user_inputs:
    # Run pipeline
    response = pipeline.run(input=msg)
    print(f"Response of Pipeline: {response}")
    print(str(response))

I get an error like this:

Traceback (most recent call last):
  File "/home/pin/queryPipeline.py", line 128, in <module>
    response = pipeline.run(
  File "/usr/local/lib/python3.10/dist-packages/llama_index/core/instrumentation/dispatcher.py", line 265, in wrapper
    result = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/llama_index/core/query_pipeline/query.py", line 413, in run
    outputs, _ = self._run(
  File "/usr/local/lib/python3.10/dist-packages/llama_index/core/instrumentation/dispatcher.py", line 265, in wrapper
    result = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/llama_index/core/query_pipeline/query.py", line 780, in _run
    result_output_dicts, intermediate_dicts = self._run_multi(
  File "/usr/local/lib/python3.10/dist-packages/llama_index/core/instrumentation/dispatcher.py", line 265, in wrapper
    result = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/llama_index/core/query_pipeline/query.py", line 937, in _run_multi
    output_dict = module.run_component(**module_input)
  File "/usr/local/lib/python3.10/dist-packages/llama_index/core/base/query_pipeline/query.py", line 203, in run_component
    component_outputs = self._run_component(**kwargs)
  File "/usr/local/lib/python3.10/dist-packages/llama_index/core/response_synthesizers/base.py", line 363, in _run_component
    output = self.synthesizer.synthesize(kwargs["query_str"], kwargs["nodes"])
  File "/usr/local/lib/python3.10/dist-packages/llama_index/core/instrumentation/dispatcher.py", line 265, in wrapper
    result = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/llama_index/core/response_synthesizers/base.py", line 241, in synthesize
    response_str = self.get_response(
  File "/usr/local/lib/python3.10/dist-packages/llama_index/core/instrumentation/dispatcher.py", line 265, in wrapper
    result = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/llama_index/core/response_synthesizers/tree_summarize.py", line 162, in get_response
    response = self._llm.predict(
  File "/usr/local/lib/python3.10/dist-packages/llama_index/core/instrumentation/dispatcher.py", line 265, in wrapper
    result = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/llama_index/core/llms/llm.py", line 579, in predict
    response = self.complete(formatted_prompt, formatted=True)
  File "/usr/local/lib/python3.10/dist-packages/llama_index/core/instrumentation/dispatcher.py", line 265, in wrapper
    result = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/llama_index/core/llms/callbacks.py", line 431, in wrapped_llm_predict
    f_return_val = f(_self, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/llama_index/llms/vllm/base.py", line 436, in complete
    output = get_response(response)
  File "/usr/local/lib/python3.10/dist-packages/llama_index/llms/vllm/utils.py", line 9, in get_response
    return data["text"]
KeyError: 'text'

I need to know how to integrate this vLLM-served model in the given query pipeline.

dosubot[bot] commented 1 month ago

Hey @RishabhSingh021! 👋 I'm here to help you with any bugs, questions, or contributions you have. Let's squash those bugs together!

I found a similar unsolved discussion that might be relevant to your issue: Unable to perform LLM inference when using VllmServer [1].

To integrate a vLLM-served model in the given query pipeline using LlamaIndex, you can use the VllmServer class provided in the llama_index.llms.vllm module. Below is an example of how to set up and use the VllmServer class:

from llama_index.llms.vllm import VllmServer

# Define functions to format messages and completions for the model
def messages_to_prompt(messages):
    prompt = "\n".join([str(x) for x in messages])
    return f"<s>[INST] {prompt} [/INST] </s>\n"

def completion_to_prompt(completion):
    return f"<s>[INST] {completion} [/INST] </s>\n"

# Initialize the VllmServer with the appropriate parameters
llm = VllmServer(
    api_url="http://localhost:8000/v1/completions",
    model="TheBloke/zephyr-7B-beta-AWQ",
    max_new_tokens=256,
    temperature=0.1,
    messages_to_prompt=messages_to_prompt,
    completion_to_prompt=completion_to_prompt,
)

# Use the complete method to get a response from the model
response = llm.complete("What is a black hole?")
print(response.text)

In your case, you are using the TreeSummarize class with the QueryPipeline. Here is how you can set it up:

from llama_index.core.query_pipeline import QueryPipeline, InputComponent
from llama_index.core.response_synthesizers import TreeSummarize
from llama_index.llms.vllm import VllmServer

# Define functions to format messages and completions for the model
def messages_to_prompt(messages):
    prompt = "\n".join([str(x) for x in messages])
    return f"<s>[INST] {prompt} [/INST] </s>\n"

def completion_to_prompt(completion):
    return f"<s>[INST] {completion} [/INST] </s>\n"

# Initialize the VllmServer with the appropriate parameters
llm = VllmServer(
    api_url="http://localhost:8000/v1/completions",
    model="TheBloke/zephyr-7B-beta-AWQ",
    max_new_tokens=256,
    temperature=0.1,
    messages_to_prompt=messages_to_prompt,
    completion_to_prompt=completion_to_prompt,
)

# Set up the summarizer and query pipeline
summarizer = TreeSummarize(llm=llm)
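# hybrid_index is assumed to be an index you have already built from your documents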
retriever = hybrid_index.as_retriever(similarity_top_k=6)

pipeline = QueryPipeline(
    modules={
        "input": InputComponent(),
        "retriever": retriever,
        "summarizer": summarizer,
    },
    verbose=True
)

pipeline.add_link("input", "retriever")
pipeline.add_link("input", "summarizer", dest_key="query_str")
pipeline.add_link("retriever", "summarizer", dest_key="nodes")

user_inputs = [
    "What is Polarin DCI Wave?",
    "Is PAYG available for Ports?",
    "Is PAYG available for Virtual Router?",
    "Is virtual appliance available with PAYG plan?"
]

for msg in user_inputs:
    # Run pipeline
    response = pipeline.run(input=msg)
    print(f"Response of Pipeline: {response}")
    print(str(response))

The error you encountered (KeyError: 'text') means the JSON returned by your vLLM server has no top-level 'text' key, which is what get_response in llama_index/llms/vllm/utils.py expects. VllmServer appears to target vLLM's native API server, whose /generate endpoint returns {"text": [...]}, whereas the OpenAI-compatible /v1/completions endpoint you are pointing it at nests the generated text under "choices". Ensure the server you call returns the format VllmServer expects, or adjust the get_response function to handle the response your server actually produces.
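
For example, here is a minimal, untested sketch of how you might patch the text extraction so it accepts both response shapes. It assumes, as the traceback suggests, that get_response is imported into llama_index.llms.vllm.base, so rebinding it there takes effect; verify this against the version you have installed:

import json
import llama_index.llms.vllm.base as vllm_base

def patched_get_response(response):
    data = json.loads(response.content)
    if "text" in data:
        # Native vLLM api_server (/generate) format: {"text": ["..."]}
        return data["text"]
    # OpenAI-compatible /v1/completions format: {"choices": [{"text": "..."}]}
    return [choice["text"] for choice in data["choices"]]

# Apply the patch before constructing the LLM and the pipeline
vllm_base.get_response = patched_get_response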

If the issue persists, you might want to check the server logs or the response format to ensure it matches the expected structure.
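
To check the response format directly, you can hit the endpoint yourself and inspect the raw JSON (the prompt and parameters below are placeholders; adjust them to your setup):

import requests

payload = {
    "model": "TheBloke/zephyr-7B-beta-AWQ",
    "prompt": "What is a black hole?",
    "max_tokens": 64,
    "temperature": 0.1,
}
resp = requests.post("http://localhost:8000/v1/completions", json=payload)
print(resp.status_code)
print(resp.json())  # a top-level "text" key vs. a "choices" list tells you which format you have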

To continue talking to Dosu, mention @dosu.