vanna-ai / vanna

🤖 Chat with your SQL database 📊. Accurate Text-to-SQL Generation via LLMs using RAG 🔄.
https://vanna.ai/docs/
MIT License
11.48k stars 919 forks

[BUG] SQL extraction does not recognize SQL statements wrapped in parentheses #687

Open kellan04 opened 1 week ago

kellan04 commented 1 week ago

When the SQL generated by the LLM begins and ends with parentheses, the regular expressions in extract_sql cannot extract the complete statement: the leading "(" is dropped.

Scenario: there are two tables, and the question requires a cross-table query. An example of the generated SQL is: (SELECT xxx) UNION (SELECT xxx);

The current source does not handle this case: src/vanna/base/base.py
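The failure can be reproduced in isolation with the pattern the current code uses for non-markdown responses (SELECT.*?;, which also appears as a commented-out line in the modified version). This is a minimal sketch; the sample response is a made-up instance of the scenario above, with the xxx placeholders replaced by hypothetical table and column names.

```python
import re

# Pattern the current extract_sql uses for responses that are not markdown formatted
ORIGINAL_PATTERN = r"SELECT.*?;"

# Hypothetical LLM response following the scenario above
llm_response = "(SELECT id FROM orders_2023) UNION (SELECT id FROM orders_2024);"

sqls = re.findall(ORIGINAL_PATTERN, llm_response, re.DOTALL)
extracted = sqls[-1]

# The match starts at the first "SELECT", so the opening "(" is lost
print(extracted)  # SELECT id FROM orders_2023) UNION (SELECT id FROM orders_2024);
print(extracted.count("("), extracted.count(")"))  # 1 2
```

The extracted statement has one more ")" than "(", which is exactly what the database later rejects as a syntax error.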

My modified version:

def extract_sql(self, llm_response: str) -> str:
    """
    Example:
    ```python
    vn.extract_sql("Here's the SQL query in a code block: ```sql\nSELECT * FROM customers\n```")
    ```

    Extracts the SQL query from the LLM response. This is useful in case the LLM response contains other information besides the SQL query.
    Override this function if your LLM responses need custom extraction logic.

    Args:
        llm_response (str): The LLM response.

    Returns:
        str: The extracted SQL query.
    """
    # Method of VannaBase in src/vanna/base/base.py; relies on the module-level `import re` there.

    # If the llm_response contains a CTE (with clause), extract the last sql between WITH and ;
    sqls = re.findall(r"\bWITH\b .*?;", llm_response, re.DOTALL)
    if sqls:
        sql = sqls[-1]
        self.log(title="Extracted SQL", message=f"{sql}")
        return sql

    # If the llm_response is not markdown formatted, extract the last sql from SELECT up to ;
    pattern = r"(\(?\s*SELECT\s+.*?\s*\)*?;)"  # match SELECT statements with or without parentheses
    # pattern = r"SELECT.*?;"  # original pattern, which drops a leading "("
    sqls = re.findall(pattern, llm_response, re.DOTALL)
    if sqls:
        sql = sqls[-1]
        self.log(title="Extracted SQL", message=f"{sql}")
        return sql

    # If the llm_response contains a markdown code block, with or without the sql tag, extract the last sql from it
    sqls = re.findall(r"```sql\n(.*)```", llm_response, re.DOTALL)
    if sqls:
        sql = sqls[-1]
        self.log(title="Extracted SQL", message=f"{sql}")
        return sql

    sqls = re.findall(r"```(.*)```", llm_response, re.DOTALL)
    if sqls:
        sql = sqls[-1]
        self.log(title="Extracted SQL", message=f"{sql}")
        return sql

    # Nothing matched: return the response unchanged
    return llm_response
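Since the docstring says extract_sql should be overridden when responses need custom extraction logic, the new pattern can also be applied from a subclass instead of editing base.py. This is a sketch under assumptions: StubVanna is a hypothetical stand-in for your configured Vanna class (it only supplies the log method the override calls), and ParenAwareVanna is an illustrative name. In a real project the subclass would inherit from your own Vanna class rather than the stub.

```python
import re

# Three backticks, built here only so this snippet contains no literal fence
FENCE = "`" * 3


class StubVanna:
    """Hypothetical stand-in for a configured Vanna class; only log() is used below."""

    def log(self, message: str, title: str = "Info") -> None:
        print(f"{title}: {message}")


class ParenAwareVanna(StubVanna):
    """Hypothetical subclass; inherit from your real Vanna class instead of StubVanna."""

    # Same checks, in the same order, as the modified extract_sql in this issue
    PATTERNS = [
        r"\bWITH\b .*?;",                 # CTE: WITH ... ;
        r"(\(?\s*SELECT\s+.*?\s*\)*?;)",  # SELECT ... ; with an optional leading "("
        FENCE + r"sql\n(.*)" + FENCE,     # fenced block tagged sql
        FENCE + r"(.*)" + FENCE,          # any fenced block
    ]

    def extract_sql(self, llm_response: str) -> str:
        for pattern in self.PATTERNS:
            sqls = re.findall(pattern, llm_response, re.DOTALL)
            if sqls:
                sql = sqls[-1]  # keep the last match, as the original does
                self.log(title="Extracted SQL", message=f"{sql}")
                return sql
        # Nothing matched: fall back to the raw response
        return llm_response


vn = ParenAwareVanna()
response = "(\n  SELECT a FROM t\n  LIMIT 1)\nUNION ALL\n(\n  SELECT a FROM u\n  LIMIT 1\n);"
print(vn.extract_sql(response) == response)  # True
```

One side effect of the new pattern: because \s* follows the optional \(?, a match can start at whitespace just before SELECT, so the extracted SQL may carry leading whitespace from the response. That is harmless for execution, but you can strip() the result if you compare it against expected strings.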
svetozar02 commented 6 days ago

I'm running into the same issue. The leading ( character is getting removed during SQL extraction.

LLM Response: 
        (
            SELECT ****
            LIMIT 1)
        UNION ALL 
        (
            SELECT ****
            LIMIT 1
        );
Extracted SQL: SELECT ****
            LIMIT 1)
        UNION ALL 
        (
            SELECT ****
            LIMIT 1
        );
An error occurred while executing SQL: syntax error at or near ")"
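The error follows directly from the extracted text above: it contains a ")" with no matching "(" before it. A cheap pre-execution check can flag this kind of extraction bug before the query reaches the database. The parens_balanced helper below is hypothetical (it is not part of vanna) and is a rough heuristic rather than a SQL parser: it ignores parentheses inside single-quoted string literals, but does not handle comments or quoted identifiers. The "****" redactions from the log are kept as-is.

```python
def parens_balanced(sql: str) -> bool:
    """Hypothetical check: True if every ')' closes an earlier '(' and none are left open."""
    depth = 0
    in_string = False
    for ch in sql:
        if ch == "'":
            in_string = not in_string  # an escaped '' toggles twice and cancels out
        elif in_string:
            continue  # ignore parentheses inside string literals
        elif ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth < 0:
                return False  # a ')' with no matching '(' before it
    return depth == 0 and not in_string


# The extracted SQL from the log above, with the leading "(" missing
extracted = "SELECT ****\n    LIMIT 1)\nUNION ALL\n(\n    SELECT ****\n    LIMIT 1\n);"
print(parens_balanced(extracted))         # False
print(parens_balanced("(\n" + extracted))  # True
```

Calling such a check on the output of extract_sql, and skipping execution or asking the LLM to regenerate when it returns False, would surface the problem as an extraction failure rather than a database syntax error.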
kellan04 commented 6 days ago

Yes, that is why I use the new regex.

svetozar02 commented 5 days ago

@kellan04 yup, your regex worked and I'm using it.