import sqlite3
import urllib.request
from contextlib import closing
from pathlib import Path

from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain_community.chat_models.ollama import ChatOllama
from langchain_community.utilities.sql_database import SQLDatabase
# Build the local Chinook sample database (Chinook.db) from the upstream
# SQLite script, then open it through LangChain's SQLDatabase wrapper.
# The download/build only runs when Chinook.db is missing: re-executing the
# script against an already-populated file is presumably not idempotent
# (plain CREATE TABLE / INSERT statements — confirm against the script).
CHINOOK_DB = Path("Chinook.db")
CHINOOK_SQL = "Chinook_Sqlite.sql"
CHINOOK_SQL_URL = (
    "https://raw.githubusercontent.com/lerocha/chinook-database/master/"
    "ChinookDatabase/DataSources/Chinook_Sqlite.sql"
)

if not CHINOOK_DB.exists():
    urllib.request.urlretrieve(CHINOOK_SQL_URL, filename=CHINOOK_SQL)
    # "utf-8-sig" decodes exactly like "utf-8" but also drops a leading
    # byte-order mark if the downloaded file happens to carry one.
    with open(CHINOOK_SQL, "r", encoding="utf-8-sig") as file:
        script = file.read()
    try:
        # sqlite3.connect() creates the database file if it does not exist;
        # closing() releases the connection even if the script raises.
        with closing(sqlite3.connect(CHINOOK_DB)) as conn:
            conn.executescript(script)
    except Exception:
        # Remove the half-built file so the exists() check above does not
        # reuse an incomplete database on the next run.
        CHINOOK_DB.unlink(missing_ok=True)
        raise

db = SQLDatabase.from_uri(f"sqlite:///{CHINOOK_DB}")
# LangChain calls this function with the chain input, so it must accept one
# positional argument; defining it with no parameters raises an error.
def get_schema(_):
    """Return the schema text produced by ``db.get_table_info()``.

    The single argument (the chain input handed over by LangChain) is only
    echoed between "###" markers for debugging; it does not affect the
    returned value.
    """
    for debug_line in ("###", _, "in get_schema", "###"):
        print(debug_line)
    return db.get_table_info()
# Smoke test: run a hand-written query directly against the database before
# any model is involved.
employee_count_query = "SELECT COUNT(EmployeeId) AS TotalEmployees\nFROM Employee;"
response = db.run(employee_count_query)
print(response)
# Prompt asking the model to pick the relevant tables/fields from the schema
# and then write a query runnable in sqlite3 for the question.
template = """
Given a database schema and a question, analyze the question to identify relevant tables and fields from the schema.
Then, construct an SQL query that can be executed in sqlite3 to retrieve the answer to the question.
# Schema: {schema}
# Question: {question}
# Answer: SQL Query"""
prompt_template = ChatPromptTemplate.from_template(template)
chat_model = ChatOllama(model="gemma:2b")

# First step: keep the incoming dict (e.g. "question") and add a "schema"
# key whose value is computed by get_schema.
add_schema = RunnablePassthrough.assign(schema=get_schema)
# Pipeline: input dict -> filled prompt -> chat model -> plain-text reply.
sql_gen_chain = add_schema | prompt_template | chat_model | StrOutputParser()
# Ask the chain to write SQL for a sample question and print the text reply.
sample_input = {"question": "How many employees are there?"}
res = sql_gen_chain.invoke(sample_input)
print(res)