eunja511005 / AutoCoding

0 stars 0 forks source link

Create SQL query from chat input #192

Open eunja511005 opened 2 months ago

eunja511005 commented 2 months ago
import urllib.request
import sqlite3
from langchain_community.utilities.sql_database import SQLDatabase
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain_community.chat_models.ollama import ChatOllama

# Download the Chinook sample-database SQL script next to this file.
urllib.request.urlretrieve(
    "https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql",
    filename="Chinook_Sqlite.sql",
)

# Open the database connection (the database file is created if it does not exist yet).
conn = sqlite3.connect('Chinook.db')
try:
    with open("Chinook_Sqlite.sql", 'r', encoding='utf-8') as file:
        script = file.read()

    # Run the whole SQL script in one call.
    conn.executescript(script)
finally:
    # Always release the connection, even when reading or executing the script
    # fails; otherwise the handle on Chinook.db would stay open.
    conn.close()

# Wrap the freshly populated SQLite file for use by LangChain.
db = SQLDatabase.from_uri("sqlite:///Chinook.db")

# A parameter is required here: without one, the call fails with an error.
def get_schema(_):
    """Print a debug trace of the received argument and return the table info of `db`.

    The argument `_` is only echoed to stdout between two "###" marker lines;
    it does not influence the returned value, which comes from
    `db.get_table_info()`.
    """
    # One print call, newline-separated, emits the same four lines of output.
    print("###", _, "in get_schema", "###", sep="\n")

    table_info = db.get_table_info()
    return table_info

# Sanity check: run a hand-written query directly against the database.
response = db.run(
    "SELECT COUNT(EmployeeId) AS TotalEmployees\n"
    "FROM Employee;"
)
print(response)

# Earlier, shorter prompt kept for reference:
#
# template = """Based on the table schema below,
# write a SQL query that would answer the user's question:
#
# {schema}
#
# Question: {question}
# SQL Query:"""

# Current prompt, assembled line by line; {schema} and {question} are the
# placeholders filled in by the chain below.
template = (
    "\n"
    "Given a database schema and a question, analyze the question to identify relevant tables and fields from the schema. \n"
    "Then, construct an SQL query that can be executed in sqlite3 to retrieve the answer to the question.\n"
    "\n"
    "# Schema: {schema}\n"
    "# Question: {question}\n"
    "# Answer: SQL Query"
)

prompt_template = ChatPromptTemplate.from_template(template)
chat_model = ChatOllama(model="gemma:2b")

# Pipeline: add a "schema" key (computed by get_schema) to the incoming dict,
# render the prompt, call the chat model, and parse the reply to a string.
sql_gen_chain = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt_template
    | chat_model
    | StrOutputParser()
)

res = sql_gen_chain.invoke({"question": "How many employees are there?"})
print(res)
eunja511005 commented 2 months ago
!pip install langchain
import os
from google.colab import userdata

hf_token = userdata.get('HF_TOKEN')

print(hf_token)