run-llama / llama_index

LlamaIndex is a data framework for your LLM applications
https://docs.llamaindex.ai
MIT License

[Question]: SQL query response has 'sql\n' from SQLAutoVectorQueryEngine #13999

Closed: thin-amicus closed this issue 1 week ago

thin-amicus commented 3 months ago

Question

I encountered an issue while using SQLAutoVectorQueryEngine with GPT-4, an MS SQL database, and a vector tool. When I call query_engine.query(context), the generated SQL query sometimes starts with 'sql\n', which makes the query invalid, so no result is returned. The problem is intermittent: it does not happen on every query.

Here's a snippet of the code:


sql_query_engine = SQLTableRetrieverQueryEngine(
    sql_database, obj_index.as_retriever(similarity_top_k=6)
)

sql_tool = QueryEngineTool.from_defaults(
    query_engine=sql_query_engine,
    description=(
        "Useful for translating a natural language query into a SQL query."
    ),
)

vector_tool = QueryEngineTool.from_defaults(
    query_engine=retriever_query_engine,
    description=(
        "Useful for answering semantic questions about different streets."
    ),
)

query_engine = SQLAutoVectorQueryEngine(
    sql_tool, vector_tool, llm=OpenAI(model=os.environ["GPT_VERSION"])
)

# Concatenate prompt and history
context = f"{history}\n{prompt}" if history else prompt
response = query_engine.query(context)
print(f"Query engine response: {response}")

# Extract the response message from the custom response object
response_message = response.response
response_metadata = response.metadata

# Result

Querying SQL database: The first choice is most relevant as it specifically addresses the translation of natural language queries into SQL queries for different types of property transactions, including sale transactions, which is directly related to the user's questions about transactions
SQL query: sql
SELECT Block, Street, Amount    
FROM transactions

SQL response: It appears there was an error in retrieving the data. The SQL query was invalid. I recommend checking the database schema or query syntax and trying again to ensure the correct information can be fetched. If you need further assistance, please let me know how I can help!

dosubot[bot] commented 3 months ago

Hey there, @thin-amicus! I'm here to assist you with any bugs, questions, or contributions while waiting for a human maintainer. Let's tackle this together!

To keep 'sql\n' out of the generated SQL query, where it makes the query invalid, you can modify the parse_response_to_sql method of the DefaultSQLParser or PGVectorSQLParser class so that it removes the unwanted 'sql\n' from the response. Here is the relevant code with the modification applied:

class DefaultSQLParser(BaseSQLParser):
    """Default SQL Parser."""

    def parse_response_to_sql(self, response: str, query_bundle: QueryBundle) -> str:
        """Parse response to SQL."""
        sql_query_start = response.find("SQLQuery:")
        if sql_query_start != -1:
            response = response[sql_query_start:]
            # TODO: move to removeprefix after Python 3.9+
            if response.startswith("SQLQuery:"):
                response = response[len("SQLQuery:") :]
        sql_result_start = response.find("SQLResult:")
        if sql_result_start != -1:
            response = response[:sql_result_start]

        # Remove unwanted 'sql\n' from the response
        response = response.replace('sql\n', '').strip()

        return response.strip().strip("```").strip()

class PGVectorSQLParser(BaseSQLParser):
    """PGVector SQL Parser."""

    def __init__(self, embed_model: BaseEmbedding) -> None:
        """Initialize params."""
        self._embed_model = embed_model

    def parse_response_to_sql(self, response: str, query_bundle: QueryBundle) -> str:
        """Parse response to SQL."""
        sql_query_start = response.find("SQLQuery:")
        if sql_query_start != -1:
            response = response[sql_query_start:]
            # TODO: move to removeprefix after Python 3.9+
            if response.startswith("SQLQuery:"):
                response = response[len("SQLQuery:") :]
        sql_result_start = response.find("SQLResult:")
        if sql_result_start != -1:
            response = response[:sql_result_start]

        # Remove unwanted 'sql\n' from the response
        response = response.replace('sql\n', '').strip()

        # this gets you the sql string with [query_vector] placeholders
        raw_sql_str = response.strip().strip("```").strip()
        query_embedding = self._embed_model.get_query_embedding(query_bundle.query_str)
        query_embedding_str = str(query_embedding)
        return raw_sql_str.replace("[query_vector]", query_embedding_str)
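
For context, the stray 'sql\n' most likely comes from the model wrapping its answer in a Markdown code fence tagged sql: strip("```") removes the backtick characters but leaves the language tag behind. The following is a minimal standalone sketch in plain Python, not the library's implementation. The helper name strip_sql_fence and both regular expressions are assumptions made for illustration. It removes only a surrounding fence and its optional tag, rather than every 'sql\n' in the string:

```python
import re

# Hypothetical helper (not part of LlamaIndex): strips one surrounding
# Markdown code fence, including an optional "sql" language tag.
_FENCE_OPEN = re.compile(r"^```[ \t]*(?:sql\b)?[ \t]*\n?", re.IGNORECASE)
_FENCE_CLOSE = re.compile(r"\n?[ \t]*```$")


def strip_sql_fence(response: str) -> str:
    """Remove a leading/trailing fence; leave the query body untouched."""
    text = response.strip()
    text = _FENCE_OPEN.sub("", text, count=1)
    text = _FENCE_CLOSE.sub("", text, count=1)
    return text.strip()


fenced = "```sql\nSELECT Block, Street, Amount\nFROM transactions\n```"

# str.strip("```") strips backtick characters only, so the tag survives.
leftover = fenced.strip().strip("```").strip()

# A global replace also alters a query in which a line ends with "sql".
plain = "SELECT id FROM raw_sql\nWHERE id = 1"
damaged = plain.replace("sql\n", "")

print(repr(leftover))
print(repr(strip_sql_fence(fenced)))
print(repr(damaged))
print(repr(strip_sql_fence(plain)))
```

Anchoring the patterns to the start and end of the response is a deliberate choice here: text inside the query, such as an identifier ending in "sql" right before a line break, is left alone.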

Additionally, you can refer to the test_sql_index_query and test_sql_index_async_query functions in the llama-index-core/tests/indices/struct_store/test_sql_query.py file for examples of how SQL queries are tested and validated. These tests ensure that the SQL queries are correctly parsed and executed, and can help you verify that your modifications work as expected. For instance:

def test_sql_index_query(
    mock_service_context: ServiceContext,
    struct_kwargs: Tuple[Dict, Dict],
) -> None:
    """Test SQLStructStoreIndex."""
    # ... setup code ...
    sql_to_test = "SELECT user_id, foo FROM test_table"
    sql_query_engine = SQLStructStoreQueryEngine(index, **query_kwargs)
    response = sql_query_engine.query(sql_to_test)
    assert str(response) == "[(2, 'bar'), (8, 'hello')]"
    # ... more tests ...
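
Note that the test above passes hand-written SQL to the engine, so it does not exercise the response-parsing step. A minimal sketch of a parser-level check follows, assuming the patched parsing steps are copied into a standalone function. Here parse_response_to_sql is a simplified stand-in for the DefaultSQLParser method (no query_bundle argument, no LlamaIndex imports), and the test names are hypothetical:

```python
def parse_response_to_sql(response: str) -> str:
    """Simplified stand-in mirroring the patched DefaultSQLParser steps."""
    sql_query_start = response.find("SQLQuery:")
    if sql_query_start != -1:
        # Keep only the text after the "SQLQuery:" marker.
        response = response[sql_query_start + len("SQLQuery:"):]
    sql_result_start = response.find("SQLResult:")
    if sql_result_start != -1:
        # Drop everything from the "SQLResult:" marker onwards.
        response = response[:sql_result_start]
    # The modification under test: drop the 'sql\n' language tag.
    response = response.replace("sql\n", "").strip()
    return response.strip().strip("```").strip()


def test_fenced_response_is_cleaned() -> None:
    raw = (
        "SQLQuery: ```sql\n"
        "SELECT Block, Street, Amount\n"
        "FROM transactions\n"
        "```\n"
        "SQLResult: ..."
    )
    expected = "SELECT Block, Street, Amount\nFROM transactions"
    assert parse_response_to_sql(raw) == expected


def test_unfenced_response_is_unchanged() -> None:
    assert parse_response_to_sql("SQLQuery: SELECT 1") == "SELECT 1"
```

Both functions follow pytest naming conventions, so they can be run with pytest or called directly.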

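This modification removes every occurrence of 'sql\n' from the response string, so the stray prefix no longer makes the SQL query invalid [1][2][3]. Because the replacement is applied to the whole string, it also removes the sequence when it occurs inside the query itself, for example where an identifier ending in "sql" is followed by a line break.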

nick-youngblut commented 3 months ago

@thin-amicus I've reported a similar issue: https://github.com/run-llama/llama_index/issues/13491

Any ideas on how to fix it?