run-llama / llama_index

LlamaIndex is a data framework for your LLM applications
https://docs.llamaindex.ai
MIT License

[Bug]: Streaming Response doesn't work if verbose is on for SQLAutoVectorQueryEngine #13893

Open dhirajsuvarna opened 1 month ago

dhirajsuvarna commented 1 month ago

Bug Description

While using SQLAutoVectorQueryEngine, if I set verbose=True and provide a vector_query_tool built with streaming=True, the returned StreamingResponse doesn't stream.

On analysis, I found that the response_gen generator is consumed before the query engine returns it, here:

if self._verbose:
    print_text(f"Query Engine response: {response}\n", color="pink")

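This happens because StreamingResponse.__str__() drains the generator whenever response_txt is None: it iterates response_gen to build the full text and caches the result. Interpolating the response into the f-string above calls __str__(), so by the time the caller receives the StreamingResponse, response_gen is already exhausted.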

    def __str__(self) -> str:
        """Convert to string representation."""
        if self.response_txt is None and self.response_gen is not None:
            response_txt = ""
            for text in self.response_gen:
                response_txt += text
            self.response_txt = response_txt
        return self.response_txt or "None"
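To see the mechanism in isolation, here is a minimal sketch using a stand-in dataclass (`FakeStreamingResponse` is hypothetical, not the LlamaIndex class) with the same `__str__` logic as the excerpt. Formatting the object into an f-string, as the verbose `print_text` call does, invokes `__str__` and drains the generator, so a caller iterating `response_gen` afterwards gets nothing.

```python
from dataclasses import dataclass
from typing import Generator, Optional


@dataclass
class FakeStreamingResponse:
    """Hypothetical stand-in mirroring StreamingResponse.__str__ (not the real class)."""

    response_gen: Optional[Generator[str, None, None]]
    response_txt: Optional[str] = None

    def __str__(self) -> str:
        # Same logic as the library excerpt: drain the generator and cache the text.
        if self.response_txt is None and self.response_gen is not None:
            response_txt = ""
            for text in self.response_gen:
                response_txt += text
            self.response_txt = response_txt
        return self.response_txt or "None"


def tokens() -> Generator[str, None, None]:
    yield from ["Hello", ", ", "world"]


response = FakeStreamingResponse(response_gen=tokens())

# What the verbose branch does: the f-string calls format(), which calls __str__().
log_line = f"Query Engine response: {response}\n"

# The caller now tries to stream, but the generator is already exhausted.
streamed = list(response.response_gen)

print(repr(log_line))  # 'Query Engine response: Hello, world\n'
print(streamed)        # []
```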

Version

0.10.40

Steps to Reproduce

from dotenv import load_dotenv

load_dotenv()

from llama_index.core import VectorStoreIndex, SimpleDirectoryReader

documents = SimpleDirectoryReader("data").load_data()
index = VectorStoreIndex.from_documents(documents)
query_engine = index.as_query_engine(streaming=True)
response_stream = query_engine.query(
    "What did the author do growing up?",
)
# print() calls StreamingResponse.__str__(), which drains response_gen
print(response_stream)
# so this loop has nothing left to yield
for text in response_stream.response_gen:
    print(text)

Relevant Logs/Tracebacks

No output
dosubot[bot] commented 1 month ago

To fix this, make sure nothing iterates over response_gen before the StreamingResponse is returned. Note that str(response), an f-string such as f"{response}", and logging calls that format the response all invoke __str__(), which drains the generator.

Here is a possible solution:

  1. Ensure that any logging or printing of the generator's content is done in a way that does not consume the generator.
  2. Return the StreamingResponse without iterating over its response_gen.
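Point 1 could look like the following sketch. `safe_response_text` is a hypothetical helper, not part of LlamaIndex: it duck-types on the `response_gen` and `response_txt` attributes shown in the `StreamingResponse` excerpt and returns a placeholder instead of calling `str()` while the stream is still pending.

```python
from typing import Any, List


def safe_response_text(response: Any) -> str:
    """Return text for logging without draining a pending stream (hypothetical helper)."""
    pending = (
        getattr(response, "response_gen", None) is not None
        and getattr(response, "response_txt", None) is None
    )
    if pending:
        # Calling str() here would iterate response_gen to completion.
        return "<streaming response; not yet consumed>"
    # Plain responses, or streams already cached in response_txt, are safe to stringify.
    return str(response)


class _PendingStream:
    """Stand-in object with an unconsumed generator."""

    def __init__(self) -> None:
        self.response_gen = iter(["a", "b"])
        self.response_txt = None


stream_resp = _PendingStream()
line = f"Query Engine response: {safe_response_text(stream_resp)}\n"
print(line, end="")

# The caller still receives every token.
remaining: List[str] = list(stream_resp.response_gen)
print(remaining)  # ['a', 'b']
```

In the verbose branch, the idea would be to interpolate `safe_response_text(response)` instead of `response` itself; whether a placeholder is acceptable log output is a judgment call for the maintainers.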

Here is the _query_sql_other method, which queries the SQL database and then the other query engine in sequence:

def _query_sql_other(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
    """Query SQL database + other query engine in sequence."""
    # first query SQL database
    sql_response = self._sql_query_tool.query_engine.query(query_bundle)
    if not self._use_sql_join_synthesis:
        return sql_response

    sql_query = (
        sql_response.metadata["sql_query"] if sql_response.metadata else None
    )
    if self._verbose:
        print_text(f"SQL query: {sql_query}\n", color="yellow")
        print_text(f"SQL response: {sql_response}\n", color="yellow")

    # given SQL db, transform query into new query
    new_query = self._sql_augment_query_transform(
        query_bundle.query_str,
        metadata={
            "sql_query": _format_sql_query(sql_query),
            "sql_query_response": str(sql_response),
        },
    )

    if self._verbose:
        print_text(
            f"Transformed query given SQL response: {new_query.query_str}\n",
            color="blue",
        )
    logger.info(f"> Transformed query given SQL response: {new_query.query_str}")
    if self._sql_augment_query_transform.check_stop(new_query):
        return sql_response

    other_response = self._other_query_tool.query_engine.query(new_query)
    if self._verbose:
        print_text(f"query engine response: {other_response}\n", color="pink")
    logger.info(f"> query engine response: {other_response}")

    response_str = self._llm.predict(
        self._sql_join_synthesis_prompt,
        query_str=query_bundle.query_str,
        sql_query_str=sql_query,
        sql_response_str=str(sql_response),
        query_engine_query_str=new_query.query_str,
        query_engine_response_str=str(other_response),
    )
    if self._verbose:
        print_text(f"Final response: {response_str}\n", color="green")
    response_metadata = {
        **(sql_response.metadata or {}),
        **(other_response.metadata or {}),
    }
    source_nodes = other_response.source_nodes
    return Response(
        response_str,
        metadata=response_metadata,
        source_nodes=source_nodes,
    )

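Note, however, that this method still calls str(other_response) in the verbose print, in the logger.info call, and in the query_engine_response_str argument to self._llm.predict, so the other engine's response_gen is drained on this path regardless of verbose, and the method returns a plain Response rather than a StreamingResponse. The consumption reported in this issue happens where the query engine's response is returned directly, so the verbose print on that path must avoid interpolating the StreamingResponse.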
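Where a query engine returns the underlying engine's StreamingResponse directly (the case in the bug report), an alternative to skipping the verbose print is to wrap the generator so each token is echoed as the caller pulls it. `echo_tokens` below is a hypothetical sketch, not a LlamaIndex API.

```python
from typing import Callable, Iterable, Iterator, List


def echo_tokens(gen: Iterable[str], sink: Callable[[str], None]) -> Iterator[str]:
    """Yield tokens unchanged, handing each one to `sink` as it streams (hypothetical)."""
    for token in gen:
        sink(token)  # e.g. a verbose print or logger call, one token at a time
        yield token


logged: List[str] = []
wrapped = echo_tokens(iter(["The ", "author ", "wrote."]), logged.append)

# Creating the wrapper consumes nothing; tokens flow only when the caller iterates.
logged_before = len(logged)  # 0
received = "".join(wrapped)

print(received)  # The author wrote.
print(logged)    # ['The ', 'author ', 'wrote.']
```

Applied to the real object, the idea would be to reassign `response.response_gen` to the wrapped generator before returning it. This assumes the field is writable, as the non-frozen `@dataclass` excerpt of StreamingResponse suggests, and that downstream code reads tokens only through `response_gen`.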

Additionally, ensure that you have properly configured both the LLM and the query engine to support streaming. According to the LlamaIndex documentation, you need to use an LLM that supports streaming and set streaming=True when building the query engine. Here is how you can set it up:

query_engine = index.as_query_engine(streaming=True, similarity_top_k=1)

Or, if using the low-level API:

from llama_index.core import get_response_synthesizer
from llama_index.core.query_engine import RetrieverQueryEngine

synth = get_response_synthesizer(streaming=True, ...)
query_engine = RetrieverQueryEngine(response_synthesizer=synth, ...)

After configuring, calling query will return a StreamingResponse object, and you can iterate over the tokens as they arrive:

for text in streaming_response.response_gen:
    # do something with text as they arrive.
    pass

Alternatively, you can print the text as it arrives:

streaming_response.print_response_stream()
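One caveat when iterating `response_gen` by hand: in the StreamingResponse excerpt, only the class's own methods fill `response_txt`, so after a manual loop `str(response)` iterates an already exhausted generator, caches an empty string, and returns `"None"`. A sketch that keeps the text by collecting it during the loop, again using a hypothetical stand-in class whose `__str__` is equivalent to the excerpt's, and assuming `response_txt` may be set directly:

```python
from dataclasses import dataclass
from typing import Iterator, List, Optional


@dataclass
class FakeStreamingResponse:
    """Hypothetical stand-in with __str__ equivalent to the excerpt (not the real class)."""

    response_gen: Optional[Iterator[str]]
    response_txt: Optional[str] = None

    def __str__(self) -> str:
        if self.response_txt is None and self.response_gen is not None:
            self.response_txt = "".join(self.response_gen)
        return self.response_txt or "None"


# Pitfall: streaming by hand leaves response_txt unset.
naive = FakeStreamingResponse(response_gen=iter(["a", "b"]))
for _ in naive.response_gen:
    pass
naive_str = str(naive)  # generator is exhausted, so this is "None"

# Workaround: collect tokens while streaming, then cache the full text.
resp = FakeStreamingResponse(response_gen=iter(["a", "b"]))
parts: List[str] = []
for text in resp.response_gen:
    print(text, end="", flush=True)  # stream to the user as tokens arrive
    parts.append(text)
resp.response_txt = "".join(parts)
print()
print(naive_str, str(resp))  # None ab
```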

The StreamingResponse class in LlamaIndex wraps the token generator. Note that __str__, get_response, and print_response_stream all iterate response_gen to completion and cache the result in response_txt, so after any of them runs the generator is exhausted; this is exactly the behavior behind the reported bug. Call them only once you no longer need to stream. Here is the relevant part of the StreamingResponse class:

@dataclass
class StreamingResponse:
    """StreamingResponse object.

    Returned if streaming=True.

    Attributes:
        response_gen: The response generator.

    """

    response_gen: TokenGen
    source_nodes: List[NodeWithScore] = field(default_factory=list)
    metadata: Optional[Dict[str, Any]] = None
    response_txt: Optional[str] = None

    def __str__(self) -> str:
        """Convert to string representation."""
        if self.response_txt is None and self.response_gen is not None:
            response_txt = ""
            for text in self.response_gen:
                response_txt += text
            self.response_txt = response_txt
        return self.response_txt or "None"

    def get_response(self) -> Response:
        """Get a standard response object."""
        if self.response_txt is None and self.response_gen is not None:
            response_txt = ""
            for text in self.response_gen:
                response_txt += text
            self.response_txt = response_txt
        return Response(self.response_txt, self.source_nodes, self.metadata)

    def print_response_stream(self) -> None:
        """Print the response stream."""
        if self.response_txt is None and self.response_gen is not None:
            response_txt = ""
            for text in self.response_gen:
                print(text, end="", flush=True)
                response_txt += text
            self.response_txt = response_txt
        else:
            print(self.response_txt)

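For reference, the StreamingGeneratorCallbackHandler, a LangChain callback handler built on the BaseCallbackHandler exposed through llama_index.core.bridge.langchain, pushes each new token onto a queue that get_response_gen then yields from. Here is the relevant part of the class: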

from queue import Queue
from threading import Event
from typing import Any, Generator, List, Optional
from uuid import UUID

from llama_index.core.bridge.langchain import BaseCallbackHandler, LLMResult

class StreamingGeneratorCallbackHandler(BaseCallbackHandler):
    """Streaming callback handler."""

    def __init__(self) -> None:
        self._token_queue: Queue = Queue()
        self._done = Event()

    def __deepcopy__(self, memo: Any) -> "StreamingGeneratorCallbackHandler":
        # NOTE: hack to bypass deepcopy in langchain
        return self

    def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
        """Run on new LLM token. Only available when streaming is enabled."""
        self._token_queue.put_nowait(token)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        self._done.set()

    def on_llm_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        self._done.set()

    def get_response_gen(self) -> Generator:
        while True:
            if not self._token_queue.empty():
                token = self._token_queue.get_nowait()
                yield token
            elif self._done.is_set():
                break
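`get_response_gen` above polls the queue in a tight loop until `_done` is set, which keeps a CPU core busy while waiting for tokens. Here is a sketch of the same producer/consumer hand-off using a blocking `Queue.get()` and an end-of-stream sentinel; `BlockingTokenStream` and `_END` are illustrative names, not LlamaIndex code.

```python
import threading
from queue import Queue
from typing import Generator, List

_END = object()  # sentinel marking the end of the stream


class BlockingTokenStream:
    """Illustrative producer/consumer token stream (not LlamaIndex code)."""

    def __init__(self) -> None:
        self._token_queue: Queue = Queue()

    def on_llm_new_token(self, token: str) -> None:
        self._token_queue.put(token)

    def on_llm_end(self) -> None:
        # Call this on errors too, so the consumer never blocks forever.
        self._token_queue.put(_END)

    def get_response_gen(self) -> Generator[str, None, None]:
        while True:
            item = self._token_queue.get()  # blocks instead of spinning
            if item is _END:
                break
            yield item


stream = BlockingTokenStream()


def fake_llm() -> None:
    # Simulated LLM producing tokens on another thread.
    for tok in ["Hello", " ", "world"]:
        stream.on_llm_new_token(tok)
    stream.on_llm_end()


producer = threading.Thread(target=fake_llm)
producer.start()
received: List[str] = list(stream.get_response_gen())
producer.join()
print("".join(received))  # Hello world
```

The sentinel goes through the same queue as the tokens, so it can only be seen after every earlier token has been yielded.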

In short, avoid calling str(), get_response(), or print_response_stream() on a StreamingResponse (directly or through f-strings and logging) before returning it, so the caller receives an unconsumed response_gen and streaming works as intended [1][2][3][4][5].


dhirajsuvarna commented 1 month ago

Need some human support.