run-llama / llama_index

LlamaIndex is a data framework for your LLM applications
https://docs.llamaindex.ai
MIT License

[Question]: SQLTableRetrieverQueryEngine does not support streaming in async mode #15114

Status: Open. nazim-ashman-oc opened this issue 2 months ago.

nazim-ashman-oc commented 2 months ago


Question

This is BaseSQLTableQueryEngine, which several other engines inherit from:

class BaseSQLTableQueryEngine(BaseQueryEngine):
    """Base SQL Table query engine.

    NOTE: Any Text-to-SQL application should be aware that executing
    arbitrary SQL queries can be a security risk. It is recommended to
    take precautions as needed, such as using restricted roles, read-only
    databases, sandboxing, etc.
    """

    def __init__(
        ....

    def _query(self, query_bundle: QueryBundle) -> Response:
        """Answer a query."""
        retrieved_nodes, metadata = self.sql_retriever.retrieve_with_metadata(
            query_bundle
        )

        sql_query_str = metadata["sql_query"]
        if self._synthesize_response:
            partial_synthesis_prompt = self._response_synthesis_prompt.partial_format(
                sql_query=sql_query_str,
            )
            response_synthesizer = get_response_synthesizer(
                llm=self._llm,
                callback_manager=self.callback_manager,
                text_qa_template=partial_synthesis_prompt,
                refine_template=self._refine_synthesis_prompt,
                verbose=self._verbose,
                streaming=self._streaming,
            )
            response = response_synthesizer.synthesize(
                query=query_bundle.query_str,
                nodes=retrieved_nodes,
            )
            cast(Dict, response.metadata).update(metadata)
            return cast(Response, response)
        else:
            response_str = "\n".join([node.node.text for node in retrieved_nodes])
            return Response(response=response_str, metadata=metadata)

    async def _aquery(self, query_bundle: QueryBundle) -> Response:
        """Answer a query."""
        retrieved_nodes, metadata = await self.sql_retriever.aretrieve_with_metadata(
            query_bundle
        )

        sql_query_str = metadata["sql_query"]
        if self._synthesize_response:
            partial_synthesis_prompt = self._response_synthesis_prompt.partial_format(
                sql_query=sql_query_str,
            )

            response_synthesizer = get_response_synthesizer(
                llm=self._llm,
                callback_manager=self.callback_manager,
                text_qa_template=partial_synthesis_prompt,
                refine_template=self._refine_synthesis_prompt,
            )
            response = await response_synthesizer.asynthesize(
                query=query_bundle.query_str,
                nodes=retrieved_nodes,
            )
            cast(Dict, response.metadata).update(metadata)
            return cast(Response, response)
        else:
            response_str = "\n".join([node.node.text for node in retrieved_nodes])
            return Response(response=response_str, metadata=metadata)

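I noticed that in async mode (calling aquery instead of query), the streaming flag is not passed to the response synthesizer, so streaming breaks. The sync _query path passes streaming=self._streaming to get_response_synthesizer, but _aquery does not. Adding that argument in _aquery makes it work.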
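
The underlying pattern is that the sync and async paths each build the synthesizer's keyword arguments by hand, so an option passed in one (here streaming, and also verbose) can silently go missing from the other. Below is a minimal, self-contained sketch of one way to prevent that drift: route both paths through a single helper. ToyEngine, fake_get_response_synthesizer and _synthesizer_kwargs are hypothetical stand-ins for illustration, not llama_index code.

```python
import asyncio
from typing import Any, Dict, List

# Records the kwargs of every call so the sync and async paths can be compared.
calls: List[Dict[str, Any]] = []


def fake_get_response_synthesizer(**kwargs: Any) -> Dict[str, Any]:
    """Hypothetical stand-in for get_response_synthesizer: just records its kwargs."""
    calls.append(kwargs)
    return kwargs


class ToyEngine:
    """Hypothetical engine mirroring the shape of BaseSQLTableQueryEngine."""

    def __init__(self, streaming: bool, verbose: bool) -> None:
        self._streaming = streaming
        self._verbose = verbose

    def _synthesizer_kwargs(self, sql_query: str) -> Dict[str, Any]:
        # Single source of truth for the options both paths must pass.
        return {
            "text_qa_template": f"prompt for {sql_query}",
            "verbose": self._verbose,
            "streaming": self._streaming,
        }

    def query(self, sql_query: str) -> Dict[str, Any]:
        return fake_get_response_synthesizer(**self._synthesizer_kwargs(sql_query))

    async def aquery(self, sql_query: str) -> Dict[str, Any]:
        return fake_get_response_synthesizer(**self._synthesizer_kwargs(sql_query))


engine = ToyEngine(streaming=True, verbose=False)
engine.query("SELECT 1")
asyncio.run(engine.aquery("SELECT 1"))
# Both paths received identical options, including streaming=True.
print(calls[0] == calls[1], calls[1]["streaming"])
```

With a shared helper, adding a new option in one place updates both query and aquery, which is exactly the kind of mismatch this issue describes.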

Also, if we are streaming, why is it casting to Response (return cast(Response, response))? Shouldn't it be cast to either StreamingResponse or AsyncStreamingResponse?
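
For context on the cast question: typing.cast only informs a static type checker; at runtime it returns its argument unchanged. So cast(Response, response) does not convert a streaming object into a plain response. The caller receives whatever the synthesizer returned, and only the declared type is misleading. A minimal sketch, using hypothetical stand-in classes that merely borrow the llama_index names (they are not the library's classes):

```python
import asyncio
from typing import AsyncGenerator, cast


class Response:
    """Hypothetical stand-in for a non-streaming response (not llama_index's class)."""

    def __init__(self, response: str) -> None:
        self.response = response


class AsyncStreamingResponse:
    """Hypothetical stand-in holding an async token generator."""

    def __init__(self, response_gen: AsyncGenerator[str, None]) -> None:
        self.response_gen = response_gen


async def tokens() -> AsyncGenerator[str, None]:
    # Simulates an LLM emitting tokens one at a time.
    for token in ["SELECT", " ok"]:
        yield token


async def main() -> str:
    streaming = AsyncStreamingResponse(tokens())
    # cast() is only a hint for the type checker: no conversion happens.
    declared = cast(Response, streaming)
    assert declared is streaming
    assert isinstance(declared, AsyncStreamingResponse)
    # A caller trusting the declared type and reading .response would fail here.
    assert not hasattr(declared, "response")
    return "".join([token async for token in declared.response_gen])


result = asyncio.run(main())
print(result)  # SELECT ok
```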

dosubot[bot] commented 2 months ago

To fix streaming mode not being passed to the response synthesizer in BaseSQLTableQueryEngine._aquery, pass the streaming parameter through, and return the appropriate response type when streaming is enabled. Here is the updated _aquery method:

async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:  # Response or a streaming response
    """Answer a query."""
    retrieved_nodes, metadata = await self.sql_retriever.aretrieve_with_metadata(
        query_bundle
    )

    sql_query_str = metadata["sql_query"]
    if self._synthesize_response:
        partial_synthesis_prompt = self._response_synthesis_prompt.partial_format(
            sql_query=sql_query_str,
        )

        response_synthesizer = get_response_synthesizer(
            llm=self._llm,
            callback_manager=self.callback_manager,
            text_qa_template=partial_synthesis_prompt,
            refine_template=self._refine_synthesis_prompt,
            verbose=self._verbose,  # match the sync _query path
            streaming=self._streaming,
        )
        response = await response_synthesizer.asynthesize(
            query=query_bundle.query_str,
            nodes=retrieved_nodes,
        )
        cast(Dict, response.metadata).update(metadata)
        if self._streaming:
            return cast(AsyncStreamingResponse, response)
        return cast(Response, response)
    else:
        response_str = "\n".join([node.node.text for node in retrieved_nodes])
        return Response(response=response_str, metadata=metadata)

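In this updated method, streaming=self._streaming is passed to get_response_synthesizer, so aquery honors the streaming setting just as query does, and when streaming is enabled the result is cast to AsyncStreamingResponse instead of Response.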

This should resolve the issue of streaming mode not being passed and ensure the correct response type is returned.
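
Since aquery can now hand back either a plain response or an async streaming one, calling code has to branch on the runtime type. A minimal sketch of such a consumer, assuming hypothetical stand-in classes that only mimic the shape (a response string versus an async token generator in response_gen); the real llama_index classes carry more fields and methods.

```python
import asyncio
from dataclasses import dataclass
from typing import AsyncGenerator, Tuple, Union


@dataclass
class PlainResponse:
    """Hypothetical stand-in for a non-streaming response."""
    response: str


@dataclass
class AsyncStreamResponse:
    """Hypothetical stand-in for an async streaming response."""
    response_gen: AsyncGenerator[str, None]


async def read_text(resp: Union[PlainResponse, AsyncStreamResponse]) -> str:
    """Return the full answer text regardless of which response type arrived."""
    if isinstance(resp, AsyncStreamResponse):
        # Drain the token stream; in a real app each token could be forwarded
        # to the client as it arrives instead of being buffered.
        chunks = []
        async for token in resp.response_gen:
            chunks.append(token)
        return "".join(chunks)
    return resp.response


async def fake_tokens() -> AsyncGenerator[str, None]:
    for token in ["There ", "are ", "3 ", "rows."]:
        yield token


async def main() -> Tuple[str, str]:
    streamed = await read_text(AsyncStreamResponse(fake_tokens()))
    plain = await read_text(PlainResponse("There are 3 rows."))
    return streamed, plain


streamed, plain = asyncio.run(main())
print(streamed == plain)  # True
```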

To continue talking to Dosu, mention @dosu.

nazim-ashman-oc commented 2 months ago

@dosu Should I make a PR or something then?

dosubot[bot] commented 2 months ago

Yes, creating a pull request with these changes would be a great next step. This will allow the maintainers to review and merge your improvements into the repository. Thank you for your contribution!


logan-markewich commented 2 months ago

@nazim-ashman-oc would love a PR if you have time

nazim-ashman-oc commented 2 months ago


@logan-markewich https://github.com/run-llama/llama_index/pull/15129

First PR to an open-source project, hope it's OK.