I am using below code to generate the queries but for particular cases it generates queries with incorrect syntax, it does mistakes like - sometimes it add one extra bracket ')' or may be sometimes it misses to add bracket or other small things like comma or so on,
How to fix it, for reference below is the code i am using -
from langchain_community.utilities import SQLDatabase
from langchain import PromptTemplate
from langchain_experimental.sql import SQLDatabaseChain
from langchain_experimental.chat_models import Llama2Chat
from langchain.llms import LlamaCpp
from dotenv import load_dotenv
import os
import sqlglot
from langchain.chains import LLMChain
class SQLQueryGenerator:
def init(self):
self.load_env_variables()
self.setup_database()
self.setup_llm()
def load_env_variables(self):
load_dotenv()
self.db_uri = os.getenv('DB_CONNECTION_STRING')
self.model_path = os.getenv('MODEL_PATH')
self.temperature = float(os.getenv('TEMPERATURE'))
self.max_tokens = int(os.getenv('MAX_TOKENS'))
self.top_p = float(os.getenv('TOP_P'))
self.n_ctx = int(os.getenv('N_CTX'))
self.table_names = os.getenv('TABLES_NAME').split(',')
self.second_model_path = os.getenv('SECOND_MODEL_PATH')
def setup_database(self):
self.db = SQLDatabase.from_uri(self.db_uri, include_tables=self.table_names, view_support=True)
print(self.db.dialect)
def setup_llm(self):
self.llm = LlamaCpp(
model_path=self.model_path,
temperature=self.temperature,
max_tokens=self.max_tokens,
top_p=self.top_p,
verbose=True,
n_ctx=self.n_ctx
)
def create_template(self, question):
return f"""### Instructions:
Your task is to convert a question into a SQL query, given a MySQL database schema.
Adhere to these rules:
- **Deliberately go through the question and database schema word by word** to appropriately answer the question
- **Strictly Do not use 'ilike' instead use 'like' for the queries where it is required
- **We have a MySQL database so generate the query as per MySQL syntax only** not as per other SQL databases like Postgres, Oracle, and other SQL databases.
- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.
- When creating a ratio, always cast the numerator as float
- **Also make sure that generated query is syntactically correct and will produce correct results on execution.
### Input:
Generate a SQL query that answers the question `{question}`.
This query will run on a database whose schema is represented in this string:
CREATE TABLE users (
id int NOT NULL AUTO_INCREMENT,
first_name varchar(50) NOT NULL,
last_name varchar(50) NOT NULL,
dob_day int NOT NULL,
dob_month int NOT NULL,
dob_year int NOT NULL,
gender enum('male','female','other') NOT NULL,
email varchar(100) NOT NULL,
username varchar(100) NOT NULL,
password varchar(255) NOT NULL,
mobile varchar(15) NOT NULL,
PRIMARY KEY ('id'),
UNIQUE KEY 'email' ('email'),
UNIQUE KEY 'username' ('username')
);
CREATE TABLE accounts (
account_number INT NOT NULL AUTO_INCREMENT,
user_id INT NOT NULL,
balance DECIMAL(10, 2) NOT NULL DEFAULT 0,
PRIMARY KEY (account_number),
FOREIGN KEY (user_id) REFERENCES users(id) );
Examples -
SQL examples:
# Example 1:
User Question:
how many total users present ?
SQL Query:
Select count(*) as totalUsers from users
Result:
+--------------+
| totalUsers |
+--------------+
| 21 |
+--------------+
# Example 2:
User Question:
what is balance of user whose first name is newtest
SQL Query:
SELECT accounts.user_id, accounts.balance FROM users JOIN accounts ON users.id = accounts.user_id WHERE users.first_name like '%newtest%';
Result:
+--------------+-----------------+
| user_id | balance |
+--------------+-----------------+
| 21 | 4291.59 |
+--------------+-----------------+
### Response:
Based on your instructions, here is the SQL query I have generated to answer the question `{question}`:
sql
"""
def generate_sql_query(self, question):
prompt = PromptTemplate(template=self.create_template(question), input_variables=["question"])
sql_chain = SQLDatabaseChain.from_llm(prompt=prompt, llm=self.llm, db=self.db, verbose=True, top_k=10)
try:
generated_query = sql_chain.invoke(question)
if generated_query:
parsed_query = sqlglot.parse_one(generated_query.get("result"), dialect="postgres").sql(dialect="mysql")
return parsed_query
except Exception as e:
print(e)
if name == "main":
sql_generator = SQLQueryGenerator()
user_question = input("Please enter your question: ")
result = sql_generator.generate_sql_query(user_question)
if result:
print("Generated SQL Query:", result)
else:
print("An error occurred during SQL query generation.")
my requirements.txt content -NOTE - there might be some extra dependencies listed in requirements.txt
I am using below code to generate the queries but for particular cases it generates queries with incorrect syntax, it does mistakes like - sometimes it add one extra bracket ')' or may be sometimes it misses to add bracket or other small things like comma or so on,
How to fix it, for reference below is the code i am using -
from langchain_community.utilities import SQLDatabase from langchain import PromptTemplate from langchain_experimental.sql import SQLDatabaseChain from langchain_experimental.chat_models import Llama2Chat from langchain.llms import LlamaCpp from dotenv import load_dotenv import os import sqlglot from langchain.chains import LLMChain
class SQLQueryGenerator: def init(self): self.load_env_variables() self.setup_database() self.setup_llm()
if name == "main": sql_generator = SQLQueryGenerator() user_question = input("Please enter your question: ") result = sql_generator.generate_sql_query(user_question) if result: print("Generated SQL Query:", result) else: print("An error occurred during SQL query generation.")
my requirements.txt content - NOTE - there might be some extra dependencies listed in requirements.txt
pandas~=2.2.0 sqlalchemy~=2.0.25 pymysql load_dotenv chainlit~=1.0.200 langchain~=0.1.20 python-dotenv~=1.0.1 langchain_experimental huggingface_hub transformers auto_gptq huggingface-hub python-dotenv flask boto3 langchain-aws==0.1.3
Also I am using below version of SQLCODER -
sqlcoder-7b-q5_k_m.gguf
I have tried a lot, please help how to fix it ? Please do let know if something else is required from myside which could help to resolve the issue.