vanna-ai / vanna

🤖 Chat with your SQL database 📊. Accurate Text-to-SQL Generation via LLMs using RAG 🔄.
https://vanna.ai/docs/
MIT License

something wrong with `question_sql_list` #451

Closed peilongchencc closed 1 month ago

peilongchencc commented 1 month ago

Describe the bug

vanna version: 0.5.5

I suspect there is an issue with how question_sql_list is used. Every time I run a query, the entries in question_sql_list are added to the prompt as historical conversation turns, which unnecessarily consumes my GPT tokens. Could you please check this?

For example, when I asked a few questions, the resulting historical conversation is as follows:

[
    {
        "role": "system",
        "content": "You are a SQL expert. Please help to generate a SQL query to ..."
    },
    {
        "role": "user",
        "content": "Are there any branches in Shanghai?"
    },
    {
        "role": "assistant",
        "content": "SELECT DISTINCT address\nFROM irmdata.branch_information\nWHERE address LIKE '%Shanghai%';"
    },
    {
        "role": "user",
        "content": "How many branches are there in total?"
    },
    {
        "role": "assistant",
        "content": "SELECT COUNT(*) AS total_branches\nFROM irmdata.branch_information;"
    },
    {
        "role": "user",
        "content": "Are there any branches in Beijing?"
    },
    {
        "role": "assistant",
        "content": "SELECT DISTINCT address\nFROM irmdata.branch_information\nWHERE address LIKE '%Beijing%';"
    },
    {
        "role": "user",
        "content": "Check how many customers there are in total."
    },
    {
        "role": "assistant",
        "content": "SELECT COUNT(*) AS total_customers\nFROM irmdata.customer_information;"
    },
    {
        "role": "user",
        "content": "How many employees are there in total?"
    },
    {
        "role": "assistant",
        "content": "SELECT COUNT(*) AS total_employees\nFROM irmdata.employee_information;"
    },
    {
        "role": "user",
        "content": "Are there any branches in Shanghai?"
    }
]

However, the only call I make is vn_rtn = vn.ask(question="Are there any branches in Shanghai?", visualize=False).

This is just a single-turn conversation, so why is it being constructed as a historical conversation?

To Reproduce

My complete code is below. I only change the question each time, so you can run it to reproduce the result:

import os
from vanna.openai import OpenAI_Chat
from vanna.chromadb import ChromaDB_VectorStore
from loguru import logger
from dotenv import load_dotenv
import pandas as pd

load_dotenv('env_config/.env.local')

# Set up logging
logger.remove()
logger.add("chromadb_test.log", rotation="1 GB", backtrace=True, diagnose=True, format="{time} {level} {message}")

def log_dataframe(df, filename):
    """To avoid ellipses in the log, save as CSV."""
    # Ensure the logs directory exists
    os.makedirs('logs', exist_ok=True)
    # Save the DataFrame as a CSV file
    path = f"./logs/{filename}"
    df.to_csv(path, index=False)
    # Log the file location
    logger.info(f"DataFrame saved to {path}")

class MyVanna(ChromaDB_VectorStore, OpenAI_Chat):
    def __init__(self, config=None):
        ChromaDB_VectorStore.__init__(self, config=config)
        OpenAI_Chat.__init__(self, config=config)

api_key = os.getenv("OPENAI_API_KEY")

vn = MyVanna(config={'api_key': api_key, 'model': 'gpt-4o'})
vn.connect_to_mysql(host='localhost', dbname='irmdata', user='root', password='Flameaway3.', port=3306)

# Get metadata from the database in MySQL (here using a self-built irmdata) including column names, data types, default values, and comments. (Not the data itself. To get the data itself, use a SELECT statement)
df_information_schema = vn.run_sql("SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = 'irmdata'")
log_dataframe(df_information_schema, "df_information_schema.csv")

plan = vn.get_training_plan_generic(df_information_schema)
logger.info(f"Plan is:\n{plan}, {type(plan)}")

plan_summary = plan.get_summary()
logger.info(f"Plan summary result is:\n{plan_summary}")

# Train on the generated plan (comment this line out to skip training)
vn.train(plan=plan)
logger.info(f"vn after training is:\n{vn}")

# Print the training data
training_data = vn.get_training_data()
# logger.info(f"Training data is:\n{training_data}")
log_dataframe(training_data, "training_data_new.csv")

vn_rtn = vn.ask(question="Are there branches in Shanghai?", visualize=False)
logger.info(f"Query result is:\n{vn_rtn}")

Expected behavior

A single-turn conversation.

zainhoda commented 1 month ago

This is intentional. Passing the question-SQL pairs as user/assistant messages is a form of in-context learning: it is how the LLM knows to follow the pattern and generate SQL correctly.
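
As a rough illustration of that pattern, here is a minimal sketch of how retrieved question-SQL pairs could be laid out as few-shot chat turns ahead of the real question. It is not Vanna's actual source: build_messages is a hypothetical helper, and the "question" and "sql" dictionary keys are assumptions made for the example.

```python
from typing import Dict, List


def build_messages(
    initial_prompt: str,
    question_sql_list: List[Dict[str, str]],
    question: str,
) -> List[Dict[str, str]]:
    """Hypothetical helper: turn retrieved question-SQL pairs into few-shot chat turns."""
    messages = [{"role": "system", "content": initial_prompt}]
    for pair in question_sql_list:
        # Each stored pair becomes one user turn and one assistant turn.
        # The "question" and "sql" keys are assumed for this sketch.
        messages.append({"role": "user", "content": pair["question"]})
        messages.append({"role": "assistant", "content": pair["sql"]})
    # The question actually being asked always comes last.
    messages.append({"role": "user", "content": question})
    return messages


examples = [
    {
        "question": "How many branches are there in total?",
        "sql": "SELECT COUNT(*) AS total_branches FROM irmdata.branch_information;",
    },
    {
        "question": "How many employees are there in total?",
        "sql": "SELECT COUNT(*) AS total_employees FROM irmdata.employee_information;",
    },
]
messages = build_messages(
    "You are a SQL expert.", examples, "Are there any branches in Shanghai?"
)
print(len(messages))  # 1 system + 2 pairs * 2 turns + 1 final user turn = 6
```

Laid out this way, the message list looks like a multi-turn history even though only the last user message is the live question, which matches the log in the report above.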

You can reduce the number of pairs that are included by setting the n_results_sql config: https://github.com/vanna-ai/vanna/blob/a72b842d420cf1fa061e5f97d45ea08051651ebb/src/vanna/chromadb/chromadb_vector.py#L25
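
For the reproduction script above, that would presumably mean adding the key to the config dictionary passed to MyVanna. The fragment below is a sketch only: the value 2 is an arbitrary example, and the behavior has not been verified here against version 0.5.5.

```python
import os

# "n_results_sql" is the config key named in the reply above.
# The value 2 is an arbitrary example; a smaller value should mean fewer
# question-SQL pairs in the prompt, and therefore fewer tokens.
config = {
    "api_key": os.getenv("OPENAI_API_KEY"),
    "model": "gpt-4o",
    "n_results_sql": 2,
}

# Passed to the class defined in the reproduction script:
# vn = MyVanna(config=config)
```

Fewer pairs also means fewer examples for the model to imitate, so the setting trades token cost against SQL accuracy.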

peilongchencc commented 1 month ago

Ok, I thought you wanted to use question_sql_list in initial_prompt.